| @@ -162,6 +162,13 @@ namespace Tensorflow.Contexts | |||||
| return c_api.TFE_ContextHasFunction(_handle, name); | return c_api.TFE_ContextHasFunction(_handle, name); | ||||
| } | } | ||||
| public void add_function_def(FunctionDef fdef) | |||||
| { | |||||
| ensure_initialized(); | |||||
| var fdef_string = fdef.ToString(); | |||||
| c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, fdef_string.Length); | |||||
| } | |||||
| public void restore_mode() | public void restore_mode() | ||||
| { | { | ||||
| context_switches.Pop(); | context_switches.Pop(); | ||||
| @@ -358,7 +358,7 @@ namespace Tensorflow.Eager | |||||
| break; | break; | ||||
| case TF_AttrType.TF_ATTR_FUNC: | case TF_AttrType.TF_ATTR_FUNC: | ||||
| if (value is ConcreteFunction func) | if (value is ConcreteFunction func) | ||||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length); | |||||
| c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length); | |||||
| else | else | ||||
| throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC"); | ||||
| break; | break; | ||||
| @@ -30,6 +30,9 @@ namespace Tensorflow | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); | public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status); | ||||
| [DllImport(TensorFlowLibName)] | |||||
| public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, string serialized_function_def, int size); | |||||
| [DllImport(TensorFlowLibName)] | [DllImport(TensorFlowLibName)] | ||||
| public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); | public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy); | ||||
| @@ -111,7 +111,17 @@ namespace Tensorflow | |||||
| public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); | public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions(); | ||||
| public static Buffer tf_buffer(byte[] data) => new Buffer(data); | |||||
| public static Buffer tf_buffer(byte[] data = null) | |||||
| { | |||||
| if(data is not null) | |||||
| { | |||||
| return new Buffer(data); ; | |||||
| } | |||||
| else | |||||
| { | |||||
| return new Buffer(); | |||||
| } | |||||
| } | |||||
| public static IEnumerable<Operation> new_tf_operations(Graph graph) | public static IEnumerable<Operation> new_tf_operations(Graph graph) | ||||
| { | { | ||||
| @@ -15,11 +15,13 @@ namespace Tensorflow.Functions | |||||
| { | { | ||||
| protected IEnumerable<Tensor> _captured_inputs; | protected IEnumerable<Tensor> _captured_inputs; | ||||
| internal FuncGraph func_graph; | internal FuncGraph func_graph; | ||||
| protected DelayedRewriteGradientFunctions _delayed_rewrite_functions; | |||||
| protected Dictionary<string, string> _attrs; | |||||
| internal ForwardBackwardCall forward_backward; | internal ForwardBackwardCall forward_backward; | ||||
| public Tensor[] Inputs => func_graph.Inputs; | public Tensor[] Inputs => func_graph.Inputs; | ||||
| public Tensor[] CapturedInputs => func_graph.external_captures; | public Tensor[] CapturedInputs => func_graph.external_captures; | ||||
| public string Name => func_graph?.FuncName; | |||||
| public string Name => _delayed_rewrite_functions.forward().Name; | |||||
| public Tensor[] Outputs; | public Tensor[] Outputs; | ||||
| public Type ReturnType; | public Type ReturnType; | ||||
| @@ -31,6 +33,8 @@ namespace Tensorflow.Functions | |||||
| { | { | ||||
| func_graph = new FuncGraph(name); | func_graph = new FuncGraph(name); | ||||
| _captured_inputs = func_graph.external_captures; | _captured_inputs = func_graph.external_captures; | ||||
| _attrs= new Dictionary<string, string>(); | |||||
| _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||||
| } | } | ||||
| public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null) | ||||
| @@ -38,7 +42,9 @@ namespace Tensorflow.Functions | |||||
| func_graph = graph; | func_graph = graph; | ||||
| _captured_inputs = func_graph.external_captures; | _captured_inputs = func_graph.external_captures; | ||||
| ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||||
| //ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray()); | |||||
| _attrs = attrs; | |||||
| _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||||
| } | } | ||||
| public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype) | ||||
| @@ -57,6 +63,8 @@ namespace Tensorflow.Functions | |||||
| null); | null); | ||||
| func_graph.Exit(); | func_graph.Exit(); | ||||
| _captured_inputs = func_graph.external_captures; | _captured_inputs = func_graph.external_captures; | ||||
| _attrs = new Dictionary<string, string>(); | |||||
| _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||||
| } | } | ||||
| public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype) | ||||
| @@ -78,6 +86,8 @@ namespace Tensorflow.Functions | |||||
| null); | null); | ||||
| func_graph.Exit(); | func_graph.Exit(); | ||||
| _captured_inputs = func_graph.external_captures; | _captured_inputs = func_graph.external_captures; | ||||
| _attrs = new Dictionary<string, string>(); | |||||
| _delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs); | |||||
| } | } | ||||
| /*public ConcreteFunction(Func<Tensors, Tensors> func, | /*public ConcreteFunction(Func<Tensors, Tensors> func, | ||||
| @@ -176,7 +186,7 @@ namespace Tensorflow.Functions | |||||
| { | { | ||||
| g = ops.get_default_graph(); | g = ops.get_default_graph(); | ||||
| } | } | ||||
| // TODO(Rinne); complete it with `_delayed_rewrite_functions`. | |||||
| _delayed_rewrite_functions.forward().AddToGraph(g); | |||||
| } | } | ||||
| public void SetExternalCaptures(IEnumerable<Tensor> captures) | public void SetExternalCaptures(IEnumerable<Tensor> captures) | ||||
| @@ -3,6 +3,7 @@ using System; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Linq; | using System.Linq; | ||||
| using System.Text; | using System.Text; | ||||
| using Tensorflow.Contexts; | |||||
| using Tensorflow.Graphs; | using Tensorflow.Graphs; | ||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| @@ -11,9 +12,20 @@ namespace Tensorflow.Functions | |||||
| public class EagerDefinedFunction | public class EagerDefinedFunction | ||||
| { | { | ||||
| public int _num_outputs; | public int _num_outputs; | ||||
| public string Name => _func_graph.FuncName; | |||||
| FuncGraph _func_graph; | FuncGraph _func_graph; | ||||
| FunctionDef _definition; | |||||
| public string Name => _func_graph.FuncName; | |||||
| public FunctionDef Definition | |||||
| { | |||||
| get | |||||
| { | |||||
| if(_definition is null) | |||||
| { | |||||
| _definition = _get_definition(); | |||||
| } | |||||
| return _definition; | |||||
| } | |||||
| } | |||||
| public EagerDefinedFunction(string name, FuncGraph graph, | public EagerDefinedFunction(string name, FuncGraph graph, | ||||
| Tensors inputs, Tensors outputs, | Tensors inputs, Tensors outputs, | ||||
| Dictionary<string, string> attrs) | Dictionary<string, string> attrs) | ||||
| @@ -46,5 +58,39 @@ namespace Tensorflow.Functions | |||||
| return results; | return results; | ||||
| } | } | ||||
| public void AddToGraph(Graph g = null) | |||||
| { | |||||
| if(g is null && tf.Context.executing_eagerly()) | |||||
| { | |||||
| var ctx = tf.Context; | |||||
| if (!ctx.has_function(this.Name)) | |||||
| { | |||||
| ctx.add_function_def(Definition); | |||||
| } | |||||
| } | |||||
| else | |||||
| { | |||||
| if (!g.IsFunction(Name)) | |||||
| { | |||||
| g.AddFunction(this); | |||||
| } | |||||
| foreach(var f in _func_graph.Functions.Values) | |||||
| { | |||||
| if (!g.IsFunction(f.Name)) | |||||
| { | |||||
| g.AddFunction(f); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| private FunctionDef _get_definition() | |||||
| { | |||||
| var buffer = c_api_util.tf_buffer(); | |||||
| // TODO(Rinne): pywrap_tf_session.TF_FunctionToFunctionDef | |||||
| var proto_data = c_api.TF_GetBuffer(buffer); | |||||
| throw new NotImplementedException(); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -0,0 +1,41 @@ | |||||
| using System; | |||||
| using System.Collections.Generic; | |||||
| using System.Text; | |||||
| using Tensorflow.Graphs; | |||||
| namespace Tensorflow.Functions | |||||
| { | |||||
| public class DelayedRewriteGradientFunctions | |||||
| { | |||||
| static readonly string _INFERENCE_PREFIX = "__inference_"; | |||||
| static readonly string _BACKWARD_PREFIX = "__backward_"; | |||||
| static readonly string _FORWARD_PREFIX = "__forward_"; | |||||
| FuncGraph _func_graph; | |||||
| EagerDefinedFunction _inference_function; | |||||
| Dictionary<string, string> _attrs; | |||||
| int _num_inference_outputs; | |||||
| public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs) | |||||
| { | |||||
| _func_graph= func_graph; | |||||
| _inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name), | |||||
| _func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs); | |||||
| _attrs = attrs; | |||||
| _num_inference_outputs = _func_graph.Outputs.Length; | |||||
| } | |||||
| public EagerDefinedFunction forward(Tensors inference_args = null, Tensors input_tangents = null) | |||||
| { | |||||
| if(input_tangents is not null) | |||||
| { | |||||
| throw new InvalidArgumentError($"unexpectedly got forwardprop information in " + | |||||
| $"a class that does not support forwardprop."); | |||||
| } | |||||
| return _inference_function; | |||||
| } | |||||
| private static string _inference_name(string name) | |||||
| { | |||||
| return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}"; | |||||
| } | |||||
| } | |||||
| } | |||||
| @@ -22,7 +22,6 @@ namespace Tensorflow.Graphs | |||||
| public override void OnEntry(MethodExecutionArgs args) | public override void OnEntry(MethodExecutionArgs args) | ||||
| { | { | ||||
| File.WriteAllText(@"D:\temp\for_test.txt", "jyfgjyfjhfjhc"); | |||||
| // TODO: func_name can be cache in FullName + Args | // TODO: func_name can be cache in FullName + Args | ||||
| func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; | func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}"; | ||||
| @@ -15,6 +15,7 @@ public class FuncGraph : Graph, IDisposable | |||||
| public Tensors Inputs { get; set; } = new Tensors(); | public Tensors Inputs { get; set; } = new Tensors(); | ||||
| public Tensors Outputs { get; set; } = new Tensors(); | public Tensors Outputs { get; set; } = new Tensors(); | ||||
| public string Name { get; set; } | |||||
| public Dictionary<string, string> Attrs { get; set; } | public Dictionary<string, string> Attrs { get; set; } | ||||
| Dictionary<long, (Tensor, Tensor)> _captures | Dictionary<long, (Tensor, Tensor)> _captures | ||||
| @@ -39,7 +40,7 @@ public class FuncGraph : Graph, IDisposable | |||||
| outer_graph = ops.get_default_graph(); | outer_graph = ops.get_default_graph(); | ||||
| while (outer_graph.building_function) | while (outer_graph.building_function) | ||||
| outer_graph = outer_graph.OuterGraph; | outer_graph = outer_graph.OuterGraph; | ||||
| _graph_key = name; | |||||
| _graph_key = Name = name; | |||||
| building_function = true; | building_function = true; | ||||
| } | } | ||||
| @@ -48,7 +49,7 @@ public class FuncGraph : Graph, IDisposable | |||||
| outer_graph = ops.get_default_graph(); | outer_graph = ops.get_default_graph(); | ||||
| while (outer_graph.building_function) | while (outer_graph.building_function) | ||||
| outer_graph = outer_graph.OuterGraph; | outer_graph = outer_graph.OuterGraph; | ||||
| _graph_key = name; | |||||
| _graph_key = Name = name; | |||||
| building_function = true; | building_function = true; | ||||
| Attrs = attrs; | Attrs = attrs; | ||||
| // Will to test if FuncGraph has memory leak | // Will to test if FuncGraph has memory leak | ||||
| @@ -19,6 +19,8 @@ using System.Collections; | |||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Collections.Specialized; | using System.Collections.Specialized; | ||||
| using System.Linq; | using System.Linq; | ||||
| using Tensorflow.Framework; | |||||
| using Tensorflow.Functions; | |||||
| using static Tensorflow.Binding; | using static Tensorflow.Binding; | ||||
| namespace Tensorflow | namespace Tensorflow | ||||
| @@ -85,6 +87,12 @@ namespace Tensorflow | |||||
| private int _next_id_counter; | private int _next_id_counter; | ||||
| private List<Operation> _unfetchable_ops = new List<Operation>(); | private List<Operation> _unfetchable_ops = new List<Operation>(); | ||||
| private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | private List<Tensor> _unfeedable_tensors = new List<Tensor>(); | ||||
| private Dictionary<string, EagerDefinedFunction> _functions = new(); | |||||
| private VersionDef _graph_def_versions = new VersionDef() | |||||
| { | |||||
| Producer = versions.GRAPH_DEF_VERSION, | |||||
| MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER | |||||
| }; | |||||
| public string _name_stack = ""; | public string _name_stack = ""; | ||||
| protected string _graph_key; | protected string _graph_key; | ||||
| @@ -120,6 +128,7 @@ namespace Tensorflow | |||||
| protected Graph outer_graph; | protected Graph outer_graph; | ||||
| public Graph OuterGraph => outer_graph; | public Graph OuterGraph => outer_graph; | ||||
| public Dictionary<string, EagerDefinedFunction> Functions => _functions; | |||||
| public Graph() | public Graph() | ||||
| { | { | ||||
| @@ -148,8 +157,23 @@ namespace Tensorflow | |||||
| public bool IsFunction(string name) | public bool IsFunction(string name) | ||||
| { | { | ||||
| // TODO(Rinne): deal with `_functions`. | |||||
| throw new NotImplementedException(); | |||||
| return _functions.ContainsKey(tf.compat.as_str(name)); | |||||
| } | |||||
| public void AddFunction(EagerDefinedFunction function) | |||||
| { | |||||
| _check_not_finalized(); | |||||
| var name = function.Name; | |||||
| // TODO(Rinne): deal with c_graph | |||||
| _functions[tf.compat.as_str(name)] = function; | |||||
| if(_graph_def_versions.MinConsumer < 12) | |||||
| { | |||||
| _graph_def_versions.MinConsumer = 12; | |||||
| } | |||||
| } | } | ||||
| private Tensor _as_graph_element(object obj) | private Tensor _as_graph_element(object obj) | ||||
| @@ -77,11 +77,8 @@ namespace Tensorflow.Training.Saving.SavedModel | |||||
| } | } | ||||
| Dictionary<string, ConcreteFunction> loaded_gradients = new(); | Dictionary<string, ConcreteFunction> loaded_gradients = new(); | ||||
| int aa = 0; | |||||
| var temp = _sort_function_defs(library, function_deps); | |||||
| foreach (var fdef in temp) | |||||
| foreach (var fdef in _sort_function_defs(library, function_deps)) | |||||
| { | { | ||||
| aa++; | |||||
| var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types); | ||||
| if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name)) | ||||
| @@ -191,7 +188,7 @@ namespace Tensorflow.Training.Saving.SavedModel | |||||
| { | { | ||||
| if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall") | ||||
| { | { | ||||
| var function = renamed_functions[tf.compat.as_bytes(op.op.node_def.Attr["f"].Func.Name).ToString()]; | |||||
| var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name]; | |||||
| // TODO(Rinne): deal with `op._gradient_function`. | // TODO(Rinne): deal with `op._gradient_function`. | ||||
| } | } | ||||
| string gradient_op_type = null; | string gradient_op_type = null; | ||||
| @@ -375,6 +375,11 @@ namespace Tensorflow | |||||
| // Re-create everything. | // Re-create everything. | ||||
| foreach (var (node_id, proto) in _iter_all_nodes()) | foreach (var (node_id, proto) in _iter_all_nodes()) | ||||
| { | { | ||||
| if(node_id == 45) | |||||
| { | |||||
| // TODelete | |||||
| Console.WriteLine(); | |||||
| } | |||||
| if (nodes.ContainsKey(node_id)) | if (nodes.ContainsKey(node_id)) | ||||
| { | { | ||||
| continue; | continue; | ||||
| @@ -469,7 +474,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| } | } | ||||
| private void _setup_function_captures(string concrete_function_name, Dictionary<Maybe<string, int>, Trackable> nodes) | |||||
| private void _setup_function_captures(string concrete_function_name, IDictionary<Maybe<string, int>, Trackable> nodes) | |||||
| { | { | ||||
| if (_restored_concrete_functions.Contains(concrete_function_name)) | if (_restored_concrete_functions.Contains(concrete_function_name)) | ||||
| { | { | ||||
| @@ -572,7 +577,7 @@ namespace Tensorflow | |||||
| { | { | ||||
| SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), | SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id), | ||||
| SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), | SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null), | ||||
| SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(), | |||||
| SavedObject.KindOneofCase.BareConcreteFunction => _recreate_bare_concrete_function(proto.BareConcreteFunction, dependencies), | |||||
| SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable), | ||||
| SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), | SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(), | ||||
| _ => throw new NotImplementedException() | _ => throw new NotImplementedException() | ||||
| @@ -644,7 +649,7 @@ namespace Tensorflow | |||||
| } | } | ||||
| private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto, | ||||
| Dictionary<Maybe<string, int>, Trackable> dependencies) | |||||
| IDictionary<Maybe<string, int>, Trackable> dependencies) | |||||
| { | { | ||||
| var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions); | ||||
| _setup_function_captures(proto.ConcreteFunctionName, dependencies); | _setup_function_captures(proto.ConcreteFunctionName, dependencies); | ||||
| @@ -62,7 +62,7 @@ public class SequentialModelLoad | |||||
| [TestMethod] | [TestMethod] | ||||
| public void Temp() | public void Temp() | ||||
| { | { | ||||
| var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func"); | |||||
| var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func"); | |||||
| model.summary(); | model.summary(); | ||||
| } | } | ||||
| } | } | ||||