Browse Source

fix TensorFlowOpLayer memory leak.

tags/TensorFlowOpLayer
Oceania2018 4 years ago
parent
commit
1460ec899b
6 changed files with 89 additions and 41 deletions
  1. +3
    -1
      src/TensorFlowNET.Core/Functions/ConcreteFunction.cs
  2. +4
    -2
      src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs
  3. +44
    -37
      src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs
  4. +4
    -0
      src/TensorFlowNET.Core/ops.cs
  5. +3
    -0
      src/TensorFlowNET.Keras/Engine/Layer.cs
  6. +31
    -1
      src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs

+ 3
- 1
src/TensorFlowNET.Core/Functions/ConcreteFunction.cs View File

@@ -13,6 +13,7 @@ namespace Tensorflow.Functions
public class ConcreteFunction
{
FuncGraph func_graph;
ForwardBackwardCall forward_backward;
public Tensor[] Inputs => func_graph.Inputs;
public Tensor[] CapturedInputs => func_graph.external_captures;

@@ -151,7 +152,8 @@ namespace Tensorflow.Functions
return tf.Runner.Execute(tf.Context, func_graph.FuncName, func_graph.Outputs.Length, args, attrs);
}

var forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
if (forward_backward == null)
forward_backward = SelectForwardAndBackwardFunctions(args, possible_gradient_type, executing_eagerly);
var (forward_function, args_with_tangents) = forward_backward.Forward();
Tensors flat_outputs = null;
if (executing_eagerly)


+ 4
- 2
src/TensorFlowNET.Core/Functions/ForwardBackwardCall.cs View File

@@ -13,6 +13,7 @@ namespace Tensorflow.Functions
Tensors _inference_args;
Tensors _input_tangents;
bool _tape_watching;
EagerDefinedFunction forward_function;

public ForwardBackwardCall(TapeGradientFunctions functions,
Tensors inference_args,
@@ -22,10 +23,11 @@ namespace Tensorflow.Functions
_inference_args = inference_args;
_tape_watching = tape_watching;
}
public (EagerDefinedFunction, Tensors) Forward()
{
var forward_function = _functions.Forward(_inference_args);
if (forward_function == null)
forward_function = _functions.Forward(_inference_args);
return (forward_function, _inference_args);
}



+ 44
- 37
src/TensorFlowNET.Core/Functions/TapeGradientFunctions.cs View File

@@ -25,6 +25,7 @@ namespace Tensorflow.Functions
protected List<int> _forwardprop_output_indices;
protected int _num_forwardprop_outputs;
protected ConcreteFunction _backward;
BackwardFunction _backward_function_wrapper;

public TapeGradientFunctions(FuncGraph func_graph,
bool need_gradients_for_jvps)
@@ -58,60 +59,66 @@ namespace Tensorflow.Functions
/// <returns></returns>
(BackwardFunction, Tensors) _wrap_backward_function(FuncGraph forward_graph, ConcreteFunction backward, Tensors outputs)
{
var capture_mapping = new Dictionary<long, Tensor>();
foreach(var (i, output) in enumerate(outputs))
capture_mapping[forward_graph.Outputs[i].Id] = output;

var remapped_captures = new Tensors();
foreach(var capture in backward.CapturedInputs)
{
if (capture_mapping.ContainsKey(capture.Id))
remapped_captures.Add(capture_mapping[capture.Id]);
}

var backward_function_inputs = backward.Inputs.Length - backward.CapturedInputs.Length;
var recorded_outputs = new Tensors();
var relevant_outputs = outputs;
var trainable_recorded_outputs = 0;
var skip_positions = new List<int>();
foreach (var (output_index, output) in enumerate(relevant_outputs))
foreach (var (output_index, output) in enumerate(outputs))
{
if (trainable_recorded_outputs < backward_function_inputs)
recorded_outputs.Add(output);
if (gradients_util.IsTrainable(output))
trainable_recorded_outputs += 1;
else
skip_positions.Add(output_index);
}

BackwardFunction _backward_function_wrapper = (args, unneeded_gradients) =>
if(_backward_function_wrapper == null)
{
var processed_args = new Tensors();
var input_index = 0;
foreach (var (output_index, arg) in enumerate(args))
var capture_mapping = new Dictionary<long, Tensor>();
foreach (var (i, output) in enumerate(outputs))
capture_mapping[forward_graph.Outputs[i].Id] = output;

var remapped_captures = new Tensors();
foreach (var capture in backward.CapturedInputs)
{
if (skip_positions.Contains(output_index))
continue;
if (arg == null)
throw new NotImplementedException("");
processed_args.Add(arg);
input_index += 1;
if (input_index >= backward_function_inputs)
break;
if (capture_mapping.ContainsKey(capture.Id))
remapped_captures.Add(capture_mapping[capture.Id]);
}

tf.Logger.Debug($"Invoke backward function: {backward.Name}");
var gradients = backward.CallFlat(processed_args, remapped_captures);

foreach (var unneeded_gradient_index in unneeded_gradients)
var skip_positions = new List<int>();
foreach (var (output_index, output) in enumerate(outputs))
{
var index = Convert.ToInt32(unneeded_gradient_index);
if (gradients.Length <= index)
gradients.Insert(index, null);
if (!gradients_util.IsTrainable(output))
skip_positions.Add(output_index);
}

return gradients;
};
_backward_function_wrapper = (args, unneeded_gradients) =>
{
var processed_args = new Tensors();
var input_index = 0;
foreach (var (output_index, arg) in enumerate(args))
{
if (skip_positions.Contains(output_index))
continue;
if (arg == null)
throw new NotImplementedException("");
processed_args.Add(arg);
input_index += 1;
if (input_index >= backward_function_inputs)
break;
}

tf.Logger.Debug($"Invoke backward function: {backward.Name}");
var gradients = backward.CallFlat(processed_args, remapped_captures);

foreach (var unneeded_gradient_index in unneeded_gradients)
{
var index = Convert.ToInt32(unneeded_gradient_index);
if (gradients.Length <= index)
gradients.Insert(index, null);
}

return gradients;
};
}

return (_backward_function_wrapper, recorded_outputs);
}


+ 4
- 0
src/TensorFlowNET.Core/ops.cs View File

@@ -376,6 +376,10 @@ namespace Tensorflow
public static int uid_function()
=> Interlocked.Increment(ref uid_number_for_function);

static int uid_number_for_layer = 0;
public static int uid_layer()
=> Interlocked.Increment(ref uid_number_for_layer);

public static void reset_uid()
{
uid_number = -1;


+ 3
- 0
src/TensorFlowNET.Keras/Engine/Layer.cs View File

@@ -66,6 +66,8 @@ namespace Tensorflow.Keras.Engine
protected List<IVariableV1> non_trainable_weights;
public List<IVariableV1> non_trainable_variables => non_trainable_weights;

protected int id;
public int Id => id;
protected string name;
protected string base_name;
public string Name => name;
@@ -96,6 +98,7 @@ namespace Tensorflow.Keras.Engine
built = false;
SupportsMasking = false;

id = ops.uid_layer();
_init_set_name(args.Name);
trainable_weights = new List<IVariableV1>();
non_trainable_weights = new List<IVariableV1>();


+ 31
- 1
src/TensorFlowNET.Keras/Layers/TensorFlowOpLayer.cs View File

@@ -8,6 +8,7 @@ using Tensorflow.Graphs;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Engine;
using static Tensorflow.Binding;
using Tensorflow.Functions;

namespace Tensorflow.Keras.Layers
{
@@ -35,10 +36,39 @@ namespace Tensorflow.Keras.Layers
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
{
if (tf.Context.executing_eagerly())
return _defun_call(inputs);
return DeFunCall(inputs);
return MakOp(inputs);
}

ConcreteFunction function;
Tensors DeFunCall(Tensors inputs)
{
if(function == null)
{
function = new ConcreteFunction(name);
function.Enter();

int i = 0;
var graph_inputs = inputs.Select(x => tf.placeholder(x.dtype, shape: x.shape, name: $"defun_inputs_{i++}")).ToArray();
var graph_outputs = MakOp(graph_inputs);
graph_outputs = mark_as_return(graph_outputs);

function.ToGraph(graph_inputs, graph_outputs);
function.Exit();
}

var outputs = function.FilteredCall(inputs);
return outputs;
}

Tensors mark_as_return(Tensors tensors)
{
var result = new Tensors();
foreach (var tensor in tensors)
result.Add(array_ops.identity(tensor));
return result;
}

[AutoGraph]
Tensors _defun_call(Tensors inputs)
=> MakOp(inputs);


Loading…
Cancel
Save