Browse Source

add Resolve() to gradient.

tags/v0.20
Oceania2018 5 years ago
parent
commit
e3034fafa4
8 changed files with 56 additions and 40 deletions
  1. +5
    -5
      src/TensorFlowNET.Console/Program.cs
  2. +12
    -7
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Gradients/GradientTape.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  6. +5
    -5
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  7. +21
    -16
      src/TensorFlowNET.Core/System/GarbageCollector.cs
  8. +9
    -3
      src/TensorFlowNET.Core/tensorflow.cs

+ 5
- 5
src/TensorFlowNET.Console/Program.cs View File

@@ -15,16 +15,16 @@ namespace Tensorflow
int batchSize = 1000; int batchSize = 1000;


// 1 million float tensor 58.5M. // 1 million float tensor 58.5M.
// mm.Execute(10, 100 * batchSize, cases.Constant);
mm.Execute(10, 100 * batchSize, cases.Constant);


// 100K float variable 80.5M. // 100K float variable 80.5M.
//mm.Execute(10, 10 * batchSize, cases.Variable);
mm.Execute(10, 10 * batchSize, cases.Variable);


// 1 million math add 36.5M. // 1 million math add 36.5M.
// mm.Execute(10, 100 * batchSize, cases.MathAdd);
mm.Execute(10, 100 * batchSize, cases.MathAdd);


// 100K gradient 211M.
mm.Execute(100, 1 * batchSize, cases.Gradient);
// 100K gradient 210M.
mm.Execute(10, 10 * batchSize, cases.Gradient);


Console.WriteLine("Finished."); Console.WriteLine("Finished.");
Console.ReadLine(); Console.ReadLine();


+ 12
- 7
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -2,7 +2,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Tensorflow.Gradients;
using static Tensorflow.Binding;


namespace Tensorflow.Eager namespace Tensorflow.Eager
{ {
@@ -11,15 +11,12 @@ namespace Tensorflow.Eager
public EagerTensor() : base(IntPtr.Zero) public EagerTensor() : base(IntPtr.Zero)
{ {
EagerTensorHandle = c_api.TFE_NewEagerTensor(); EagerTensorHandle = c_api.TFE_NewEagerTensor();
// _id = c_api.TFE_EagerTensorId(EagerTensorHandle);
// print($"new EagerTensorHandle {EagerTensorHandle.ToString("x16")} {Id}");
} }


public EagerTensor(IntPtr handle) : base(IntPtr.Zero) public EagerTensor(IntPtr handle) : base(IntPtr.Zero)
{ {
EagerTensorHandle = handle; EagerTensorHandle = handle;
Resolve(); Resolve();
// print($"new EagerTensorHandle {EagerTensorHandle.ToString("x16")} {Id}");
} }


public EagerTensor(string value, string device_name) : base(value) public EagerTensor(string value, string device_name) : base(value)
@@ -28,7 +25,6 @@ namespace Tensorflow.Eager
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
c_api.TFE_SetEagerTensorHandle(EagerTensorHandle, tfe_tensor_handle); c_api.TFE_SetEagerTensorHandle(EagerTensorHandle, tfe_tensor_handle);
Resolve(); Resolve();
// print($"new EagerTensorHandle {EagerTensorHandle.ToString("x16")} {Id}");
} }
public EagerTensor(NDArray value, string device_name) : base(value) public EagerTensor(NDArray value, string device_name) : base(value)
@@ -37,18 +33,21 @@ namespace Tensorflow.Eager
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status); tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
c_api.TFE_SetEagerTensorHandle(EagerTensorHandle, tfe_tensor_handle); c_api.TFE_SetEagerTensorHandle(EagerTensorHandle, tfe_tensor_handle);
Resolve(); Resolve();
// print($"new EagerTensorHandle {EagerTensorHandle.ToString("x16")} {Id}");
} }


public EagerTensor Resolve() public EagerTensor Resolve()
{ {
_id = c_api.TFE_EagerTensorId(EagerTensorHandle);

if (tfe_tensor_handle == IntPtr.Zero) if (tfe_tensor_handle == IntPtr.Zero)
tfe_tensor_handle = c_api.TFE_EagerTensorHandle(EagerTensorHandle); tfe_tensor_handle = c_api.TFE_EagerTensorHandle(EagerTensorHandle);


if (_handle == IntPtr.Zero) if (_handle == IntPtr.Zero)
_handle = c_api.TFE_TensorHandleResolve(tfe_tensor_handle, status); _handle = c_api.TFE_TensorHandleResolve(tfe_tensor_handle, status);


_id = c_api.TFE_EagerTensorId(EagerTensorHandle);
/*print($"new Tensor {Id} {_handle.ToString("x16")}");
print($"new TensorHandle {Id} {tfe_tensor_handle.ToString("x16")}");
print($"new EagerTensor {Id} {EagerTensorHandle.ToString("x16")}");*/


GarbageCollector.Increase(_handle, GCItemType.TensorHandle); GarbageCollector.Increase(_handle, GCItemType.TensorHandle);
GarbageCollector.Increase(tfe_tensor_handle, GCItemType.LocalTensorHandle); GarbageCollector.Increase(tfe_tensor_handle, GCItemType.LocalTensorHandle);
@@ -62,6 +61,12 @@ namespace Tensorflow.Eager
GarbageCollector.Decrease(_handle); GarbageCollector.Decrease(_handle);
GarbageCollector.Decrease(tfe_tensor_handle); GarbageCollector.Decrease(tfe_tensor_handle);
GarbageCollector.Decrease(EagerTensorHandle); GarbageCollector.Decrease(EagerTensorHandle);

/*c_api.TF_DeleteTensor(_handle);
print($"deleting DeleteTensorHandle {Id} {tfe_tensor_handle.ToString("x16")}");
c_api.TFE_DeleteTensorHandle(tfe_tensor_handle);
print($"deleting DeleteEagerTensor {Id} {EagerTensorHandle.ToString("x16")}");
c_api.TFE_DeleteEagerTensor(EagerTensorHandle);*/
} }
} }
} }

+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow.Eager
public IntPtr EagerTensorHandle { get; set; } public IntPtr EagerTensorHandle { get; set; }
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(tfe_tensor_handle, status)); public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(tfe_tensor_handle, status));


public override int rank => c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status);
// public override int rank => c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status);


public static int GetRank(IntPtr handle) public static int GetRank(IntPtr handle)
{ {


+ 2
- 2
src/TensorFlowNET.Core/Gradients/GradientTape.cs View File

@@ -84,7 +84,7 @@ namespace Tensorflow.Gradients
new [] { (source as EagerTensor).EagerTensorHandle }, 1, new [] { (source as EagerTensor).EagerTensorHandle }, 1,
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length); results.Select(x => x.EagerTensorHandle).ToArray(), results.Length);
status.Check(true); status.Check(true);
return results[0];
return results[0].Resolve();
} }


public unsafe (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources) public unsafe (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources)
@@ -116,7 +116,7 @@ namespace Tensorflow.Gradients
_tape = null; _tape = null;
} }


return (results[0], results[1]);
return (results[0].Resolve(), results[1].Resolve());
} }


public void Dispose() public void Dispose()


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -314,7 +314,7 @@ namespace Tensorflow
}, 2, null, }, 2, null,
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length); results.Select(x => x.EagerTensorHandle).ToArray(), results.Length);
status.Check(true); status.Check(true);
return (results[0], results[1]);
return (results[0].Resolve(), results[1].Resolve());
} }


var _op = _op_def_lib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 }); var _op = _op_def_lib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 });


+ 5
- 5
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -59,16 +59,16 @@ namespace Tensorflow
return _op.outputs[0]; return _op.outputs[0];
} }


public static EagerTensor add_n(IntPtr[] inputs, string name = null)
public static IntPtr add_n(IntPtr[] inputs, string name = null)
{ {
var results = new[] { new EagerTensor() };
var results = new[] { c_api.TFE_NewEagerTensor() };
Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name, Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"AddN", name, "AddN", name,
inputs, inputs.Length, inputs, inputs.Length,
null, null,
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length);
results, results.Length);
status.Check(true); status.Check(true);
return results[0].Resolve();
return results[0];
} }


/// <summary> /// <summary>
@@ -155,7 +155,7 @@ namespace Tensorflow
op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims), op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims),
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length); results.Select(x => x.EagerTensorHandle).ToArray(), results.Length);
status.Check(true); status.Check(true);
return results[0];
return results[0].Resolve();
} }


var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims }); var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims });


+ 21
- 16
src/TensorFlowNET.Core/System/GarbageCollector.cs View File

@@ -2,27 +2,33 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Timers; using System.Timers;
using static Tensorflow.Binding;


namespace Tensorflow namespace Tensorflow
{ {
public class GarbageCollector public class GarbageCollector
{ {
static Dictionary<IntPtr, GCItemCounter> container = new Dictionary<IntPtr, GCItemCounter>(); static Dictionary<IntPtr, GCItemCounter> container = new Dictionary<IntPtr, GCItemCounter>();
static Timer timer = null;
static object locker = new object();


public static void Increase(IntPtr handle, GCItemType type)
static object locker = new object();
public static void Init()
{ {
if(timer == null)
Task.Run(() =>
{ {
timer = new Timer(300);
// Hook up the Elapsed event for the timer.
timer.Elapsed += OnTimedEvent;
timer.AutoReset = true;
timer.Enabled = true;
}
while (true)
{
Thread.Sleep(100);
Recycle();
}
});


}

public static void Increase(IntPtr handle, GCItemType type)
{
if (container.ContainsKey(handle)) if (container.ContainsKey(handle))
{ {
container[handle].RefCounter++; container[handle].RefCounter++;
@@ -52,15 +58,13 @@ namespace Tensorflow
} }
} }


private static void OnTimedEvent(object source, ElapsedEventArgs e)
private static void Recycle()
{ {
timer.Stop();

// dispose before 1 sec // dispose before 1 sec
lock (locker) lock (locker)
{ {
var items = container.Values var items = container.Values
.Where(x => x.RefCounter <= 0 && (DateTime.Now - x.LastUpdateTime).Milliseconds > 300)
.Where(x => x.RefCounter <= 0 && (DateTime.Now - x.LastUpdateTime).TotalMilliseconds > 100)
.ToArray(); .ToArray();


foreach (var item in items) foreach (var item in items)
@@ -70,12 +74,15 @@ namespace Tensorflow
switch (item.ItemType) switch (item.ItemType)
{ {
case GCItemType.TensorHandle: case GCItemType.TensorHandle:
// print($"c_api.TF_DeleteTensor({item.Handle.ToString("x16")})");
c_api.TF_DeleteTensor(item.Handle); c_api.TF_DeleteTensor(item.Handle);
break; break;
case GCItemType.LocalTensorHandle: case GCItemType.LocalTensorHandle:
// print($"c_api.TFE_DeleteTensorHandle({item.Handle.ToString("x16")})");
c_api.TFE_DeleteTensorHandle(item.Handle); c_api.TFE_DeleteTensorHandle(item.Handle);
break; break;
case GCItemType.EagerTensorHandle: case GCItemType.EagerTensorHandle:
// print($"c_api.TFE_DeleteEagerTensor({item.Handle.ToString("x16")})");
c_api.TFE_DeleteEagerTensor(item.Handle); c_api.TFE_DeleteEagerTensor(item.Handle);
break; break;
default: default:
@@ -83,8 +90,6 @@ namespace Tensorflow
} }
} }
} }

timer.Start();
} }
} }
} }

+ 9
- 3
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -49,15 +49,16 @@ namespace Tensorflow


private unsafe void InitGradientEnvironment() private unsafe void InitGradientEnvironment()
{ {
GarbageCollector.Init();

var vspace = c_api.VSpace_Handle((shape, dims, dtype) => var vspace = c_api.VSpace_Handle((shape, dims, dtype) =>
{ {
var ones = constant_op.constant(1.0f, dtype: dtype) as EagerTensor; var ones = constant_op.constant(1.0f, dtype: dtype) as EagerTensor;
return ones.EagerTensorHandle; return ones.EagerTensorHandle;
}, (gradients) => }, (gradients) =>
{ {
var input_grads = gradients.Data.Select(x => new EagerTensor(x)).ToArray();
var add_n = gen_math_ops.add_n(input_grads) as EagerTensor;
return add_n.EagerTensorHandle;
var add_n = gen_math_ops.add_n(gradients.Data);
return add_n;
}); });


ops.RegisterFromAssembly(); ops.RegisterFromAssembly();
@@ -65,6 +66,9 @@ namespace Tensorflow


c_api.TFE_RegisterGradientFunction((op_name, op_inputs, op_outputs, num_attrs, output_grads, skip_input_indices) => c_api.TFE_RegisterGradientFunction((op_name, op_inputs, op_outputs, num_attrs, output_grads, skip_input_indices) =>
{ {
/*var input_tensors = new BindingArray(op_inputs);
var output_tensors = new BindingArray(op_outputs);
var output_grad_tensors = new BindingArray(output_grads);*/
var input_tensors = new BindingTensorArray(op_inputs).Data.Select(x => new EagerTensor(x)).ToArray(); var input_tensors = new BindingTensorArray(op_inputs).Data.Select(x => new EagerTensor(x)).ToArray();
var output_tensors = new BindingTensorArray(op_outputs).Data.Select(x => new EagerTensor(x)).ToArray(); var output_tensors = new BindingTensorArray(op_outputs).Data.Select(x => new EagerTensor(x)).ToArray();
var output_grad_tensors = new BindingTensorArray(output_grads).Data.Select(x => new EagerTensor(x)).ToArray(); var output_grad_tensors = new BindingTensorArray(output_grads).Data.Select(x => new EagerTensor(x)).ToArray();
@@ -74,8 +78,10 @@ namespace Tensorflow
{ {
NumInputs = input_tensors.Length, NumInputs = input_tensors.Length,
Inputs = input_tensors, Inputs = input_tensors,
// InputHandles = input_tensors.Data,
NumOutputs = output_tensors.Length, NumOutputs = output_tensors.Length,
Outputs = output_tensors, Outputs = output_tensors,
// OutputHandles = output_tensors.Data,
SkipInputIndices = skip_input_indices_param SkipInputIndices = skip_input_indices_param
}, output_grad_tensors); }, output_grad_tensors);




Loading…
Cancel
Save