Browse Source

add Resolve() to gradient.

tags/v0.20
Oceania2018 5 years ago
parent
commit
e3034fafa4
8 changed files with 56 additions and 40 deletions
  1. +5
    -5
      src/TensorFlowNET.Console/Program.cs
  2. +12
    -7
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  3. +1
    -1
      src/TensorFlowNET.Core/Eager/EagerTensor.cs
  4. +2
    -2
      src/TensorFlowNET.Core/Gradients/GradientTape.cs
  5. +1
    -1
      src/TensorFlowNET.Core/Operations/gen_array_ops.cs
  6. +5
    -5
      src/TensorFlowNET.Core/Operations/gen_math_ops.cs
  7. +21
    -16
      src/TensorFlowNET.Core/System/GarbageCollector.cs
  8. +9
    -3
      src/TensorFlowNET.Core/tensorflow.cs

+ 5
- 5
src/TensorFlowNET.Console/Program.cs View File

@@ -15,16 +15,16 @@ namespace Tensorflow
int batchSize = 1000;

// 1 million float tensor 58.5M.
// mm.Execute(10, 100 * batchSize, cases.Constant);
mm.Execute(10, 100 * batchSize, cases.Constant);

// 100K float variable 80.5M.
//mm.Execute(10, 10 * batchSize, cases.Variable);
mm.Execute(10, 10 * batchSize, cases.Variable);

// 1 million math add 36.5M.
// mm.Execute(10, 100 * batchSize, cases.MathAdd);
mm.Execute(10, 100 * batchSize, cases.MathAdd);

// 100K gradient 211M.
mm.Execute(100, 1 * batchSize, cases.Gradient);
// 100K gradient 210M.
mm.Execute(10, 10 * batchSize, cases.Gradient);

Console.WriteLine("Finished.");
Console.ReadLine();


+ 12
- 7
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -2,7 +2,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Tensorflow.Gradients;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
@@ -11,15 +11,12 @@ namespace Tensorflow.Eager
/// <summary>
/// Creates an eager tensor backed by a freshly allocated native eager-tensor object.
/// Passes <c>IntPtr.Zero</c> to the base constructor and stores the value returned by
/// <c>c_api.TFE_NewEagerTensor()</c> in <see cref="EagerTensorHandle"/>.
/// </summary>
/// <remarks>
/// This constructor does not call <c>Resolve()</c>; in the code visible here the id and
/// the other handles are only populated by <c>Resolve()</c>, so callers that need them
/// are expected to call it afterwards (NOTE(review): confirm against call sites).
/// </remarks>
public EagerTensor() : base(IntPtr.Zero)
{
EagerTensorHandle = c_api.TFE_NewEagerTensor();
// _id = c_api.TFE_EagerTensorId(EagerTensorHandle);
// print($"new EagerTensorHandle {EagerTensorHandle.ToString("x16")} {Id}");
}

/// <summary>
/// Wraps an existing native eager-tensor handle. Passes <c>IntPtr.Zero</c> to the base
/// constructor, stores <paramref name="handle"/> in <see cref="EagerTensorHandle"/>,
/// and then calls <c>Resolve()</c>.
/// </summary>
/// <param name="handle">
/// Value assigned to <see cref="EagerTensorHandle"/>. Presumably a handle previously
/// produced by <c>c_api.TFE_NewEagerTensor()</c> or handed back by the native runtime —
/// TODO confirm against callers.
/// </param>
public EagerTensor(IntPtr handle) : base(IntPtr.Zero)
{
EagerTensorHandle = handle;
// Resolve() reads the id and derives the remaining handles from EagerTensorHandle.
Resolve();
// print($"new EagerTensorHandle {EagerTensorHandle.ToString("x16")} {Id}");
}

public EagerTensor(string value, string device_name) : base(value)
@@ -28,7 +25,6 @@ namespace Tensorflow.Eager
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
c_api.TFE_SetEagerTensorHandle(EagerTensorHandle, tfe_tensor_handle);
Resolve();
// print($"new EagerTensorHandle {EagerTensorHandle.ToString("x16")} {Id}");
}
public EagerTensor(NDArray value, string device_name) : base(value)
@@ -37,18 +33,21 @@ namespace Tensorflow.Eager
tfe_tensor_handle = c_api.TFE_NewTensorHandle(_handle, status);
c_api.TFE_SetEagerTensorHandle(EagerTensorHandle, tfe_tensor_handle);
Resolve();
// print($"new EagerTensorHandle {EagerTensorHandle.ToString("x16")} {Id}");
}

public EagerTensor Resolve()
{
_id = c_api.TFE_EagerTensorId(EagerTensorHandle);

if (tfe_tensor_handle == IntPtr.Zero)
tfe_tensor_handle = c_api.TFE_EagerTensorHandle(EagerTensorHandle);

if (_handle == IntPtr.Zero)
_handle = c_api.TFE_TensorHandleResolve(tfe_tensor_handle, status);

_id = c_api.TFE_EagerTensorId(EagerTensorHandle);
/*print($"new Tensor {Id} {_handle.ToString("x16")}");
print($"new TensorHandle {Id} {tfe_tensor_handle.ToString("x16")}");
print($"new EagerTensor {Id} {EagerTensorHandle.ToString("x16")}");*/

GarbageCollector.Increase(_handle, GCItemType.TensorHandle);
GarbageCollector.Increase(tfe_tensor_handle, GCItemType.LocalTensorHandle);
@@ -62,6 +61,12 @@ namespace Tensorflow.Eager
GarbageCollector.Decrease(_handle);
GarbageCollector.Decrease(tfe_tensor_handle);
GarbageCollector.Decrease(EagerTensorHandle);

/*c_api.TF_DeleteTensor(_handle);
print($"deleting DeleteTensorHandle {Id} {tfe_tensor_handle.ToString("x16")}");
c_api.TFE_DeleteTensorHandle(tfe_tensor_handle);
print($"deleting DeleteEagerTensor {Id} {EagerTensorHandle.ToString("x16")}");
c_api.TFE_DeleteEagerTensor(EagerTensorHandle);*/
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Eager/EagerTensor.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow.Eager
public IntPtr EagerTensorHandle { get; set; }
public override string Device => c_api.StringPiece(c_api.TFE_TensorHandleDeviceName(tfe_tensor_handle, status));

public override int rank => c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status);
// public override int rank => c_api.TFE_TensorHandleNumDims(tfe_tensor_handle, status);

public static int GetRank(IntPtr handle)
{


+ 2
- 2
src/TensorFlowNET.Core/Gradients/GradientTape.cs View File

@@ -84,7 +84,7 @@ namespace Tensorflow.Gradients
new [] { (source as EagerTensor).EagerTensorHandle }, 1,
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length);
status.Check(true);
return results[0];
return results[0].Resolve();
}

public unsafe (Tensor, Tensor) gradient(Tensor target, (ResourceVariable, ResourceVariable) sources)
@@ -116,7 +116,7 @@ namespace Tensorflow.Gradients
_tape = null;
}

return (results[0], results[1]);
return (results[0].Resolve(), results[1].Resolve());
}

public void Dispose()


+ 1
- 1
src/TensorFlowNET.Core/Operations/gen_array_ops.cs View File

@@ -314,7 +314,7 @@ namespace Tensorflow
}, 2, null,
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length);
status.Check(true);
return (results[0], results[1]);
return (results[0].Resolve(), results[1].Resolve());
}

var _op = _op_def_lib._apply_op_helper("BroadcastGradientArgs", name, new { s0, s1 });


+ 5
- 5
src/TensorFlowNET.Core/Operations/gen_math_ops.cs View File

@@ -59,16 +59,16 @@ namespace Tensorflow
return _op.outputs[0];
}

public static EagerTensor add_n(IntPtr[] inputs, string name = null)
public static IntPtr add_n(IntPtr[] inputs, string name = null)
{
var results = new[] { new EagerTensor() };
var results = new[] { c_api.TFE_NewEagerTensor() };
Status status = c_api.TFE_FastPathExecute(tf.context, tf.context.device_name,
"AddN", name,
inputs, inputs.Length,
null,
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length);
results, results.Length);
status.Check(true);
return results[0].Resolve();
return results[0];
}

/// <summary>
@@ -155,7 +155,7 @@ namespace Tensorflow
op => wrap_tfe_src.SetOpAttrs(op, "keep_dims", keep_dims),
results.Select(x => x.EagerTensorHandle).ToArray(), results.Length);
status.Check(true);
return results[0];
return results[0].Resolve();
}

var _op = _op_def_lib._apply_op_helper("Mean", name, args: new { input, reduction_indices = axis, keep_dims = keep_dims });


+ 21
- 16
src/TensorFlowNET.Core/System/GarbageCollector.cs View File

@@ -2,27 +2,33 @@
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Timers;
using static Tensorflow.Binding;

namespace Tensorflow
{
public class GarbageCollector
{
static Dictionary<IntPtr, GCItemCounter> container = new Dictionary<IntPtr, GCItemCounter>();
static Timer timer = null;
static object locker = new object();

public static void Increase(IntPtr handle, GCItemType type)
static object locker = new object();
public static void Init()
{
if(timer == null)
Task.Run(() =>
{
timer = new Timer(300);
// Hook up the Elapsed event for the timer.
timer.Elapsed += OnTimedEvent;
timer.AutoReset = true;
timer.Enabled = true;
}
while (true)
{
Thread.Sleep(100);
Recycle();
}
});

}

public static void Increase(IntPtr handle, GCItemType type)
{
if (container.ContainsKey(handle))
{
container[handle].RefCounter++;
@@ -52,15 +58,13 @@ namespace Tensorflow
}
}

private static void OnTimedEvent(object source, ElapsedEventArgs e)
private static void Recycle()
{
timer.Stop();

// dispose before 1 sec
lock (locker)
{
var items = container.Values
.Where(x => x.RefCounter <= 0 && (DateTime.Now - x.LastUpdateTime).Milliseconds > 300)
.Where(x => x.RefCounter <= 0 && (DateTime.Now - x.LastUpdateTime).TotalMilliseconds > 100)
.ToArray();

foreach (var item in items)
@@ -70,12 +74,15 @@ namespace Tensorflow
switch (item.ItemType)
{
case GCItemType.TensorHandle:
// print($"c_api.TF_DeleteTensor({item.Handle.ToString("x16")})");
c_api.TF_DeleteTensor(item.Handle);
break;
case GCItemType.LocalTensorHandle:
// print($"c_api.TFE_DeleteTensorHandle({item.Handle.ToString("x16")})");
c_api.TFE_DeleteTensorHandle(item.Handle);
break;
case GCItemType.EagerTensorHandle:
// print($"c_api.TFE_DeleteEagerTensor({item.Handle.ToString("x16")})");
c_api.TFE_DeleteEagerTensor(item.Handle);
break;
default:
@@ -83,8 +90,6 @@ namespace Tensorflow
}
}
}

timer.Start();
}
}
}

+ 9
- 3
src/TensorFlowNET.Core/tensorflow.cs View File

@@ -49,15 +49,16 @@ namespace Tensorflow

private unsafe void InitGradientEnvironment()
{
GarbageCollector.Init();

var vspace = c_api.VSpace_Handle((shape, dims, dtype) =>
{
var ones = constant_op.constant(1.0f, dtype: dtype) as EagerTensor;
return ones.EagerTensorHandle;
}, (gradients) =>
{
var input_grads = gradients.Data.Select(x => new EagerTensor(x)).ToArray();
var add_n = gen_math_ops.add_n(input_grads) as EagerTensor;
return add_n.EagerTensorHandle;
var add_n = gen_math_ops.add_n(gradients.Data);
return add_n;
});

ops.RegisterFromAssembly();
@@ -65,6 +66,9 @@ namespace Tensorflow

c_api.TFE_RegisterGradientFunction((op_name, op_inputs, op_outputs, num_attrs, output_grads, skip_input_indices) =>
{
/*var input_tensors = new BindingArray(op_inputs);
var output_tensors = new BindingArray(op_outputs);
var output_grad_tensors = new BindingArray(output_grads);*/
var input_tensors = new BindingTensorArray(op_inputs).Data.Select(x => new EagerTensor(x)).ToArray();
var output_tensors = new BindingTensorArray(op_outputs).Data.Select(x => new EagerTensor(x)).ToArray();
var output_grad_tensors = new BindingTensorArray(output_grads).Data.Select(x => new EagerTensor(x)).ToArray();
@@ -74,8 +78,10 @@ namespace Tensorflow
{
NumInputs = input_tensors.Length,
Inputs = input_tensors,
// InputHandles = input_tensors.Data,
NumOutputs = output_tensors.Length,
Outputs = output_tensors,
// OutputHandles = output_tensors.Data,
SkipInputIndices = skip_input_indices_param
}, output_grad_tensors);



Loading…
Cancel
Save