Browse Source

Align some implementations of Graph and FuncGraph.

tags/v0.100.5-BERT-load
AsakusaRinne 2 years ago
parent
commit
4e6431ed85
13 changed files with 164 additions and 21 deletions
  1. +7
    -0
      src/TensorFlowNET.Core/Contexts/Context.cs
  2. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  3. +3
    -0
      src/TensorFlowNET.Core/Eager/c_api.eager.cs
  4. +11
    -1
      src/TensorFlowNET.Core/Framework/c_api_util.cs
  5. +13
    -3
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  6. +48
    -2
      src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs
  7. +41
    -0
      src/TensorFlowNET.Core/Functions/monomorphic_function.cs
  8. +0
    -1
      src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs
  9. +3
    -2
      src/TensorFlowNET.Core/Graphs/FuncGraph.cs
  10. +26
    -2
      src/TensorFlowNET.Core/Graphs/Graph.cs
  11. +2
    -5
      src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs
  12. +8
    -3
      src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs
  13. +1
    -1
      test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

+ 7
- 0
src/TensorFlowNET.Core/Contexts/Context.cs View File

@@ -162,6 +162,13 @@ namespace Tensorflow.Contexts
return c_api.TFE_ContextHasFunction(_handle, name);
}

public void add_function_def(FunctionDef fdef)
{
ensure_initialized();
var fdef_string = fdef.ToString();
c_api.TFE_ContextAddFunctionDef(_handle, fdef_string, fdef_string.Length);
}

public void restore_mode()
{
context_switches.Pop();


+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -358,7 +358,7 @@ namespace Tensorflow.Eager
break;
case TF_AttrType.TF_ATTR_FUNC:
if (value is ConcreteFunction func)
c_api.TFE_OpSetAttrFunctionName(op, key, func.Name, func.Name.Length);
c_api.TFE_OpSetAttrFunctionName(op, key, func.func_graph.FuncName, func.func_graph.FuncName.Length);
else
throw new NotImplementedException("TF_AttrType.TF_ATTR_FUNC");
break;


+ 3
- 0
src/TensorFlowNET.Core/Eager/c_api.eager.cs View File

@@ -30,6 +30,9 @@ namespace Tensorflow
[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextOptionsSetConfig(SafeContextOptionsHandle opts, byte[] proto, ulong proto_len, SafeStatusHandle status);

[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextAddFunctionDef(SafeContextHandle ctx, string serialized_function_def, int size);

[DllImport(TensorFlowLibName)]
public static extern void TFE_ContextOptionsSetDevicePlacementPolicy(SafeContextOptionsHandle opts, ContextDevicePlacementPolicy device_policy);



+ 11
- 1
src/TensorFlowNET.Core/Framework/c_api_util.cs View File

@@ -111,7 +111,17 @@ namespace Tensorflow

public static ImportGraphDefOptions ScopedTFImportGraphDefOptions() => new ImportGraphDefOptions();

public static Buffer tf_buffer(byte[] data) => new Buffer(data);
public static Buffer tf_buffer(byte[] data = null)
{
if(data is not null)
{
return new Buffer(data); ;
}
else
{
return new Buffer();
}
}

public static IEnumerable<Operation> new_tf_operations(Graph graph)
{


+ 13
- 3
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -15,11 +15,13 @@ namespace Tensorflow.Functions
{
protected IEnumerable<Tensor> _captured_inputs;
internal FuncGraph func_graph;
protected DelayedRewriteGradientFunctions _delayed_rewrite_functions;
protected Dictionary<string, string> _attrs;
internal ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures;

public string Name => func_graph?.FuncName;
public string Name => _delayed_rewrite_functions.forward().Name;

public Tensor[] Outputs;
public Type ReturnType;
@@ -31,6 +33,8 @@ namespace Tensorflow.Functions
{
func_graph = new FuncGraph(name);
_captured_inputs = func_graph.external_captures;
_attrs= new Dictionary<string, string>();
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
}

public ConcreteFunction(FuncGraph graph, Dictionary<string, string> attrs = null)
@@ -38,7 +42,9 @@ namespace Tensorflow.Functions
func_graph = graph;
_captured_inputs = func_graph.external_captures;

ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray());
//ToGraph(graph.Inputs, graph.Outputs.Where(x => x != null).ToArray());
_attrs = attrs;
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
}

public ConcreteFunction(Func<Tensor, Tensor> func, TF_DataType dtype)
@@ -57,6 +63,8 @@ namespace Tensorflow.Functions
null);
func_graph.Exit();
_captured_inputs = func_graph.external_captures;
_attrs = new Dictionary<string, string>();
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
}

public ConcreteFunction(Func<Tensor, IDatasetV2> func, TF_DataType dtype)
@@ -78,6 +86,8 @@ namespace Tensorflow.Functions
null);
func_graph.Exit();
_captured_inputs = func_graph.external_captures;
_attrs = new Dictionary<string, string>();
_delayed_rewrite_functions = new DelayedRewriteGradientFunctions(func_graph, _attrs);
}

/*public ConcreteFunction(Func<Tensors, Tensors> func,
@@ -176,7 +186,7 @@ namespace Tensorflow.Functions
{
g = ops.get_default_graph();
}
// TODO(Rinne); complete it with `_delayed_rewrite_functions`.
_delayed_rewrite_functions.forward().AddToGraph(g);
}

public void SetExternalCaptures(IEnumerable<Tensor> captures)


+ 48
- 2
src/TensorFlowNET.Core/Functions/EagerDefinedFunction.cs View File

@@ -3,6 +3,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Tensorflow.Contexts;
using Tensorflow.Graphs;
using static Tensorflow.Binding;

@@ -11,9 +12,20 @@ namespace Tensorflow.Functions
public class EagerDefinedFunction
{
public int _num_outputs;
public string Name => _func_graph.FuncName;

FuncGraph _func_graph;
FunctionDef _definition;
public string Name => _func_graph.FuncName;
public FunctionDef Definition
{
get
{
if(_definition is null)
{
_definition = _get_definition();
}
return _definition;
}
}
public EagerDefinedFunction(string name, FuncGraph graph,
Tensors inputs, Tensors outputs,
Dictionary<string, string> attrs)
@@ -46,5 +58,39 @@ namespace Tensorflow.Functions

return results;
}

public void AddToGraph(Graph g = null)
{
if(g is null && tf.Context.executing_eagerly())
{
var ctx = tf.Context;
if (!ctx.has_function(this.Name))
{
ctx.add_function_def(Definition);
}
}
else
{
if (!g.IsFunction(Name))
{
g.AddFunction(this);
}
foreach(var f in _func_graph.Functions.Values)
{
if (!g.IsFunction(f.Name))
{
g.AddFunction(f);
}
}
}
}

private FunctionDef _get_definition()
{
var buffer = c_api_util.tf_buffer();
// TODO(Rinne): pywrap_tf_session.TF_FunctionToFunctionDef
var proto_data = c_api.TF_GetBuffer(buffer);
throw new NotImplementedException();
}
}
}

+ 41
- 0
src/TensorFlowNET.Core/Functions/monomorphic_function.cs View File

@@ -0,0 +1,41 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Graphs;

namespace Tensorflow.Functions
{
public class DelayedRewriteGradientFunctions
{
static readonly string _INFERENCE_PREFIX = "__inference_";
static readonly string _BACKWARD_PREFIX = "__backward_";
static readonly string _FORWARD_PREFIX = "__forward_";
FuncGraph _func_graph;
EagerDefinedFunction _inference_function;
Dictionary<string, string> _attrs;
int _num_inference_outputs;
public DelayedRewriteGradientFunctions(FuncGraph func_graph, Dictionary<string, string> attrs)
{
_func_graph= func_graph;
_inference_function = new EagerDefinedFunction(_inference_name(_func_graph.Name),
_func_graph, _func_graph.Inputs, _func_graph.Outputs, attrs);
_attrs = attrs;
_num_inference_outputs = _func_graph.Outputs.Length;
}

public EagerDefinedFunction forward(Tensors inference_args = null, Tensors input_tangents = null)
{
if(input_tangents is not null)
{
throw new InvalidArgumentError($"unexpectedly got forwardprop information in " +
$"a class that does not support forwardprop.");
}
return _inference_function;
}

private static string _inference_name(string name)
{
return $"{_INFERENCE_PREFIX}{name}_{ops.uid()}";
}
}
}

+ 0
- 1
src/TensorFlowNET.Core/Graphs/AutoGraphAttribute.cs View File

@@ -22,7 +22,6 @@ namespace Tensorflow.Graphs

public override void OnEntry(MethodExecutionArgs args)
{
File.WriteAllText(@"D:\temp\for_test.txt", "jyfgjyfjhfjhc");
// TODO: func_name can be cache in FullName + Args
func_name = $"{args.Method.DeclaringType.FullName}.{args.Method.Name}";



+ 3
- 2
src/TensorFlowNET.Core/Graphs/FuncGraph.cs View File

@@ -15,6 +15,7 @@ public class FuncGraph : Graph, IDisposable

public Tensors Inputs { get; set; } = new Tensors();
public Tensors Outputs { get; set; } = new Tensors();
public string Name { get; set; }
public Dictionary<string, string> Attrs { get; set; }

Dictionary<long, (Tensor, Tensor)> _captures
@@ -39,7 +40,7 @@ public class FuncGraph : Graph, IDisposable
outer_graph = ops.get_default_graph();
while (outer_graph.building_function)
outer_graph = outer_graph.OuterGraph;
_graph_key = name;
_graph_key = Name = name;
building_function = true;
}

@@ -48,7 +49,7 @@ public class FuncGraph : Graph, IDisposable
outer_graph = ops.get_default_graph();
while (outer_graph.building_function)
outer_graph = outer_graph.OuterGraph;
_graph_key = name;
_graph_key = Name = name;
building_function = true;
Attrs = attrs;
// Will to test if FuncGraph has memory leak


+ 26
- 2
src/TensorFlowNET.Core/Graphs/Graph.cs View File

@@ -19,6 +19,8 @@ using System.Collections;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Linq;
using Tensorflow.Framework;
using Tensorflow.Functions;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -85,6 +87,12 @@ namespace Tensorflow
private int _next_id_counter;
private List<Operation> _unfetchable_ops = new List<Operation>();
private List<Tensor> _unfeedable_tensors = new List<Tensor>();
private Dictionary<string, EagerDefinedFunction> _functions = new();
private VersionDef _graph_def_versions = new VersionDef()
{
Producer = versions.GRAPH_DEF_VERSION,
MinConsumer = versions.GRAPH_DEF_VERSION_MIN_CONSUMER
};

public string _name_stack = "";
protected string _graph_key;
@@ -120,6 +128,7 @@ namespace Tensorflow

protected Graph outer_graph;
public Graph OuterGraph => outer_graph;
public Dictionary<string, EagerDefinedFunction> Functions => _functions;

public Graph()
{
@@ -148,8 +157,23 @@ namespace Tensorflow

public bool IsFunction(string name)
{
// TODO(Rinne): deal with `_functions`.
throw new NotImplementedException();
return _functions.ContainsKey(tf.compat.as_str(name));
}

public void AddFunction(EagerDefinedFunction function)
{
_check_not_finalized();

var name = function.Name;

// TODO(Rinne): deal with c_graph

_functions[tf.compat.as_str(name)] = function;

if(_graph_def_versions.MinConsumer < 12)
{
_graph_def_versions.MinConsumer = 12;
}
}

private Tensor _as_graph_element(object obj)


+ 2
- 5
src/TensorFlowNET.Core/Training/Saving/SavedModel/function_deserialization.cs View File

@@ -77,11 +77,8 @@ namespace Tensorflow.Training.Saving.SavedModel
}

Dictionary<string, ConcreteFunction> loaded_gradients = new();
int aa = 0;
var temp = _sort_function_defs(library, function_deps);
foreach (var fdef in temp)
foreach (var fdef in _sort_function_defs(library, function_deps))
{
aa++;
var orig_name = _fix_fdef_in_place(fdef, functions, load_shared_name_suffix, new_gradient_op_types);

if(saved_object_graph is not null && saved_object_graph.ConcreteFunctions.ContainsKey(orig_name))
@@ -191,7 +188,7 @@ namespace Tensorflow.Training.Saving.SavedModel
{
if(op.op.type == "StatefulPartitionedCall" || op.op.type == "PartitionedCall")
{
var function = renamed_functions[tf.compat.as_bytes(op.op.node_def.Attr["f"].Func.Name).ToString()];
var function = renamed_functions[op.op.node_def.Attr["f"].Func.Name];
// TODO(Rinne): deal with `op._gradient_function`.
}
string gradient_op_type = null;


+ 8
- 3
src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs View File

@@ -375,6 +375,11 @@ namespace Tensorflow
// Re-create everything.
foreach (var (node_id, proto) in _iter_all_nodes())
{
if(node_id == 45)
{
// TODelete
Console.WriteLine();
}
if (nodes.ContainsKey(node_id))
{
continue;
@@ -469,7 +474,7 @@ namespace Tensorflow
}
}

private void _setup_function_captures(string concrete_function_name, Dictionary<Maybe<string, int>, Trackable> nodes)
private void _setup_function_captures(string concrete_function_name, IDictionary<Maybe<string, int>, Trackable> nodes)
{
if (_restored_concrete_functions.Contains(concrete_function_name))
{
@@ -572,7 +577,7 @@ namespace Tensorflow
{
SavedObject.KindOneofCase.UserObject => _recreate_user_object(proto.UserObject, node_id),
SavedObject.KindOneofCase.Function => _recreate_function(proto.Function, null),
SavedObject.KindOneofCase.BareConcreteFunction => throw new NotImplementedException(),
SavedObject.KindOneofCase.BareConcreteFunction => _recreate_bare_concrete_function(proto.BareConcreteFunction, dependencies),
SavedObject.KindOneofCase.Variable => _recreate_variable(proto.Variable),
SavedObject.KindOneofCase.CapturedTensor => throw new NotImplementedException(),
_ => throw new NotImplementedException()
@@ -644,7 +649,7 @@ namespace Tensorflow
}

private (ConcreteFunction, Action<object, object, object>) _recreate_bare_concrete_function(SavedBareConcreteFunction proto,
Dictionary<Maybe<string, int>, Trackable> dependencies)
IDictionary<Maybe<string, int>, Trackable> dependencies)
{
var fn = function_deserialization.setup_bare_concrete_function(proto, _concrete_functions);
_setup_function_captures(proto.ConcreteFunctionName, dependencies);


+ 1
- 1
test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs View File

@@ -62,7 +62,7 @@ public class SequentialModelLoad
[TestMethod]
public void Temp()
{
var model = tf.keras.models.load_model(@"D:\development\tf.net\tf_test\python_func");
var model = tf.keras.models.load_model(@"C:\Work\tf.net\tf_test\python_func");
model.summary();
}
}

Loading…
Cancel
Save