Browse Source

Remove AllocationReferenceHolder.

tags/yolov3
Oceania2018 4 years ago
parent
commit
ceaa09c842
17 changed files with 208 additions and 142 deletions
  1. +25
    -13
      src/TensorFlowNET.Console/MemoryMonitor.cs
  2. +49
    -6
      src/TensorFlowNET.Console/MemoryTestingCases.cs
  3. +10
    -4
      src/TensorFlowNET.Console/Program.cs
  4. +4
    -0
      src/TensorFlowNET.Console/Tensorflow.Console.csproj
  5. +21
    -22
      src/TensorFlowNET.Core/DisposableObject.cs
  6. +6
    -10
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs
  7. +8
    -8
      src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs
  8. +22
    -30
      src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs
  9. +8
    -0
      src/TensorFlowNET.Core/Eager/SafeOpHandle.cs
  10. +4
    -0
      src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs
  11. +1
    -1
      src/TensorFlowNET.Core/Tensorflow.Binding.csproj
  12. +10
    -13
      src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs
  13. +31
    -17
      src/TensorFlowNET.Core/Tensors/Tensor.cs
  14. +7
    -1
      src/TensorFlowNET.Core/Tensors/Tensors.cs
  15. +1
    -1
      src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs
  16. +0
    -15
      src/TensorFlowNET.Core/Variables/ResourceVariable.cs
  17. +1
    -1
      src/TensorFlowNET.Core/Variables/_UnreadVariable.cs

+ 25
- 13
src/TensorFlowNET.Console/MemoryMonitor.cs View File

@@ -1,5 +1,8 @@
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using NumSharp;
using static Tensorflow.Binding;

namespace Tensorflow
@@ -9,26 +12,35 @@ namespace Tensorflow
public void WarmUp()
{
print($"tensorflow native version: v{tf.VERSION}");
var a = tf.constant(np.ones(10, 10));
var b = tf.Variable(a);
var c = tf.Variable(b);
var d = b * c;
print(d.numpy());

GC.WaitForPendingFinalizers();
GC.Collect();
Thread.Sleep(1000);
}

public void Execute(int epoch, int iterate, Action<int> process)
{
/*GC.Collect();
GC.WaitForPendingFinalizers();
GC.Collect();*/

print($"{process.Method.Name} started...");
for (int i = 0; i < epoch; i++)

// new thread to run
Task.Run(() =>
{
var initialMemory = Process.GetCurrentProcess().PrivateMemorySize64;// GC.GetTotalMemory(true);
process(iterate);
var finalMemory = Process.GetCurrentProcess().PrivateMemorySize64; //GC.GetTotalMemory(true);
print($"Epoch {i}: {Format(finalMemory - initialMemory)}.");
}
for (int i = 0; i < epoch; i++)
{
var initialMemory = Process.GetCurrentProcess().PrivateMemorySize64;// GC.GetTotalMemory(true);
process(iterate);
var finalMemory = Process.GetCurrentProcess().PrivateMemorySize64; //GC.GetTotalMemory(true);
print($"Epoch {i}: {Format(finalMemory - initialMemory)}.");

GC.Collect();
GC.WaitForPendingFinalizers();
GC.Collect();
GC.Collect();
GC.WaitForPendingFinalizers();
}
}).Wait();

print($"Total {process.Method.Name} usage {Format(Process.GetCurrentProcess().PrivateMemorySize64)}");
}


+ 49
- 6
src/TensorFlowNET.Console/MemoryTestingCases.cs View File

@@ -21,11 +21,7 @@ namespace Tensorflow
public Action<int> Constant2x3
=> (iterate) =>
{
var nd = np.array(new byte[,]
{
{1, 2, 3},
{4, 5, 6}
});
var nd = np.arange(1000).reshape(10, 100);
for (int i = 0; i < iterate; i++)
{
var tensor = tf.constant(nd);
@@ -38,7 +34,8 @@ namespace Tensorflow
{
for (int i = 0; i < iterate; i++)
{
var tensor = tf.Variable(3112.0f);
var nd = np.arange(128 * 128 * 3).reshape(128, 128, 3);
var variable = tf.Variable(nd);
}
};

@@ -66,5 +63,51 @@ namespace Tensorflow
var grad = tape.gradient(loss, w);
}
};

public Action<int> Conv2dWithVariable
=> (iterate) =>
{
for (int i = 0; i < iterate; i++)
{
var input = array_ops.zeros((10, 32, 32, 3), dtypes.float32);
var filter = tf.Variable(array_ops.zeros((3, 3, 3, 32), dtypes.float32));
var strides = new[] { 1, 1, 1, 1 };
var dilations = new[] { 1, 1, 1, 1 };

var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Conv2D", null,
null,
input, filter,
"strides", strides,
"use_cudnn_on_gpu", true,
"padding", "VALID",
"explicit_paddings", new int[0],
"data_format", "NHWC",
"dilations", dilations);
}
};

public Action<int> Conv2dWithTensor
=> (iterate) =>
{
for (int i = 0; i < iterate; i++)
{
var input = array_ops.zeros((10, 32, 32, 3), dtypes.float32);
var filter = array_ops.zeros((3, 3, 3, 32), dtypes.float32);
var strides = new[] { 1, 1, 1, 1 };
var dilations = new[] { 1, 1, 1, 1 };

var results = tf.Runner.TFE_FastPathExecute(tf.Context, tf.Context.DeviceName,
"Conv2D", null,
null,
input, filter,
"strides", strides,
"use_cudnn_on_gpu", true,
"padding", "VALID",
"explicit_paddings", new int[0],
"data_format", "NHWC",
"dilations", dilations);
}
};
}
}

+ 10
- 4
src/TensorFlowNET.Console/Program.cs View File

@@ -12,20 +12,26 @@ namespace Tensorflow

// boot .net core 10.5M.
var mm = new MemoryMonitor();
// warm up tensorflow.net 28.5M.
// warm up tensorflow.net 37.3M.
mm.WarmUp();
var cases = new MemoryTestingCases();

int batchSize = 1000;

// 1 million tensor
mm.Execute(10, 100 * batchSize, cases.Constant);

// explanation of constant
mm.Execute(10, 100 * batchSize, cases.Constant2x3);

// 1 million float tensor 68M.
mm.Execute(10, 100 * batchSize, cases.Constant);
// +0M
mm.Execute(10, batchSize, cases.Conv2dWithTensor);

// 100K float variable 84M.
mm.Execute(10, 10 * batchSize, cases.Variable);
mm.Execute(10, batchSize, cases.Variable);

// +45M memory leak
mm.Execute(10, batchSize, cases.Conv2dWithVariable);

// 1 million math add 39M.
mm.Execute(10, 100 * batchSize, cases.MathAdd);


+ 4
- 0
src/TensorFlowNET.Console/Tensorflow.Console.csproj View File

@@ -8,6 +8,10 @@
<Platforms>AnyCPU;x64</Platforms>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<DefineConstants>TRACE;DEBUG</DefineConstants>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="SciSharp.TensorFlow.Redist" Version="2.3.1" />
</ItemGroup>


+ 21
- 22
src/TensorFlowNET.Core/DisposableObject.cs View File

@@ -22,6 +22,7 @@ namespace Tensorflow
{
/// <summary>
/// Abstract class for disposable object allocated in unmanaged runtime.
/// https://docs.microsoft.com/en-us/dotnet/api/system.idisposable.dispose?redirectedfrom=MSDN&view=net-5.0#System_IDisposable_Dispose
/// </summary>
public abstract class DisposableObject : IDisposable
{
@@ -36,24 +37,31 @@ namespace Tensorflow
=> _handle = handle;

[SuppressMessage("ReSharper", "InvertIf")]
private void internal_dispose(bool disposing)
private void Dispose(bool disposing)
{
if (_disposed)
return;

_disposed = true;

//first handle managed, they might use the unmanaged resources.
if (disposing)
{
// dispose managed state (managed objects).
DisposeManagedResources();
}

//free unmanaged memory
// free unmanaged memory
if (_handle != IntPtr.Zero)
{
// Call the appropriate methods to clean up
// unmanaged resources here.
// If disposing is false,
// only the following code is executed.
DisposeUnmanagedResources(_handle);
_handle = IntPtr.Zero;
}

// Note disposing has been done.
_disposed = true;
}

/// <summary>
@@ -68,29 +76,20 @@ namespace Tensorflow
/// </summary>
protected abstract void DisposeUnmanagedResources(IntPtr handle);

~DisposableObject()
{
internal_dispose(false);
}

public void Dispose()
{
lock (this)
{
internal_dispose(true);
GC.SuppressFinalize(this);
}
Dispose(true);
// This object will be cleaned up by the Dispose method.
// Therefore, you should call GC.SuppressFinalize to
// take this object off the finalization queue
// and prevent finalization code for this object
// from executing a second time.
GC.SuppressFinalize(this);
}

/// <summary>
/// If <see cref="_handle"/> is <see cref="IntPtr.Zero"/> then throws <see cref="ObjectDisposedException"/>
/// </summary>
/// <exception cref="ObjectDisposedException">When <see cref="_handle"/> is <see cref="IntPtr.Zero"/></exception>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
protected void EnsureNotDisposed()
~DisposableObject()
{
if (_disposed)
throw new ObjectDisposedException($"Unable to access disposed object, Type: {GetType().Name}");
Dispose(false);
}
}
}

+ 6
- 10
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_Execute.cs View File

@@ -14,6 +14,7 @@
limitations under the License.
******************************************************************************/

using System;
using System.Linq;
using Tensorflow.Contexts;
using static Tensorflow.Binding;
@@ -48,16 +49,11 @@ namespace Tensorflow.Eager
{
for (int i = 0; i < inputs.Length; ++i)
{
SafeTensorHandleHandle tensor_handle;
switch (inputs[i])
SafeTensorHandleHandle tensor_handle = inputs[i] switch
{
case EagerTensor et:
tensor_handle = et.EagerTensorHandle;
break;
default:
tensor_handle = c_api.TFE_NewTensorHandle(inputs[i], status.Handle);
break;
}
EagerTensor et => et.EagerTensorHandle,
_ => throw new NotImplementedException("")
};
c_api.TFE_OpAddInput(op, tensor_handle, status.Handle);
status.Check(true);
}
@@ -71,7 +67,7 @@ namespace Tensorflow.Eager
c_api.TFE_Execute(op, outputs, out num_outputs, status.Handle);
status.Check(true);
}
return outputs.Select(x => new EagerTensor(x)).ToArray();
return outputs.Select(x => new EagerTensor(x, op)).ToArray();
}
}
}

+ 8
- 8
src/TensorFlowNET.Core/Eager/EagerRunner.TFE_FastPathExecute.cs View File

@@ -158,7 +158,7 @@ namespace Tensorflow.Eager
c_api.TFE_Execute(op, retVals, out num_retvals, status.Handle);
status.Check(true);

var flat_result = retVals.Select(x => new EagerTensor(x)).ToArray();
var flat_result = retVals.Select(x => new EagerTensor(x, op)).ToArray();

if (op_exec_info.run_callbacks)
{
@@ -182,7 +182,11 @@ namespace Tensorflow.Eager

status.Check(true);
return op;*/
return c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle);
var op = c_api.TFE_NewOp(ctx.Handle, op_or_function_name, status.Handle);
#if TRACK_TENSOR_LIFE
print($"New OpHandle 0x{op.DangerousGetHandle().ToString("x16")}");
#endif
return op;
}

bool HasAccumulator()
@@ -219,22 +223,18 @@ namespace Tensorflow.Eager
SafeOpHandle op,
Status status)
{
SafeTensorHandleHandle input_handle;

// ConvertToTensor();
var tensor = tf.convert_to_tensor(inputs);
input_handle = tensor.EagerTensorHandle;
flattened_inputs.Add(tensor);

if (add_type_attr && !string.IsNullOrEmpty(input_arg.TypeAttr))
{
var dtype = c_api.TFE_TensorHandleDataType(input_handle);
var dtype = c_api.TFE_TensorHandleDataType(tensor.EagerTensorHandle);
c_api.TFE_OpSetAttrType(op, input_arg.TypeAttr, dtype);
flattened_attrs.Add(input_arg.TypeAttr);
flattened_attrs.Add(dtype);
}

c_api.TFE_OpAddInput(op, input_handle, status.Handle);
c_api.TFE_OpAddInput(op, tensor.EagerTensorHandle, status.Handle);
status.Check(true);

return true;


+ 22
- 30
src/TensorFlowNET.Core/Eager/EagerTensor.Creation.cs View File

@@ -6,36 +6,36 @@ namespace Tensorflow.Eager
{
public partial class EagerTensor
{
public EagerTensor() : base(IntPtr.Zero)
{

}
SafeOpHandle _opHandle;

public EagerTensor(SafeTensorHandleHandle handle) : base(IntPtr.Zero)
public EagerTensor(SafeTensorHandleHandle handle, SafeOpHandle opHandle) : base(IntPtr.Zero)
{
_opHandle = opHandle;
EagerTensorHandle = handle;
Resolve();
}

public EagerTensor(string value, string device_name) : base(value)
{
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
Resolve();
SetEagerTensorHandleAndResolve();
}

public EagerTensor(byte[] value, string device_name, TF_DataType dtype) : base(value, dType: dtype)
{
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
Resolve();
SetEagerTensorHandleAndResolve();
}

public EagerTensor(string[] value, string device_name) : base(value)
{
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
Resolve();
SetEagerTensorHandleAndResolve();
}

public EagerTensor(NDArray value, string device_name) : base(value)
{
SetEagerTensorHandleAndResolve();
}

void SetEagerTensorHandleAndResolve()
{
EagerTensorHandle = c_api.TFE_NewTensorHandle(_handle, tf.Status.Handle);
Resolve();
@@ -47,10 +47,10 @@ namespace Tensorflow.Eager

if (_handle == IntPtr.Zero)
_handle = c_api.TFE_TensorHandleResolve(EagerTensorHandle, tf.Status.Handle);
// print($"New TensorHandle {Id} 0x{_handle.ToString("x16")}");
// print($"New EagerTensorHandle {Id} {EagerTensorHandle}");
#if TRACK_TENSOR_LIFE
print($"New TensorHandle {Id} 0x{_handle.ToString("x16")}");
print($"New EagerTensorHandle {Id} {EagerTensorHandle}");
#endif
return this;
}

@@ -60,23 +60,14 @@ namespace Tensorflow.Eager
/// <returns></returns>
public Tensor AsPlaceholder(string name = null)
{
Tensor placeholder = null;
tf_with(ops.control_dependencies(null), delegate
{
placeholder = tf.placeholder(dtype, name: name);
});
var placeholder = tf_with(ops.control_dependencies(null), _ => tf.placeholder(dtype, name: name));
copy_handle_data(placeholder);
return placeholder;
}

public Tensor AsConstant(string name = null)
{
Tensor constant = null;
tf_with(ops.control_dependencies(null), delegate
{
constant = tf.constant(numpy(), name: name);
});
return constant;
return tf_with(ops.control_dependencies(null), _ => tf.constant(numpy(), name: name));
}

void copy_handle_data(Tensor target_t)
@@ -95,15 +86,16 @@ namespace Tensorflow.Eager
protected override void DisposeManagedResources()
{
base.DisposeManagedResources();

// print($"Delete EagerTensorHandle {Id} {EagerTensorHandle}");
EagerTensorHandle.Dispose();
}

protected override void DisposeUnmanagedResources(IntPtr handle)
{
base.DisposeUnmanagedResources(handle);
// print($"Delete TensorHandle {Id} 0x{_handle.ToString("x16")}");
EagerTensorHandle.Dispose();

if (_opHandle != null)
_opHandle.Dispose();
}
}
}

+ 8
- 0
src/TensorFlowNET.Core/Eager/SafeOpHandle.cs View File

@@ -16,6 +16,7 @@

using System;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
@@ -23,15 +24,22 @@ namespace Tensorflow.Eager
{
private SafeOpHandle()
{

}

public SafeOpHandle(IntPtr handle)
: base(handle)
{
#if TRACK_TENSOR_LIFE
print($"Get OpHandle 0x{handle.ToString("x16")}");
#endif
}

protected override bool ReleaseHandle()
{
#if TRACK_TENSOR_LIFE
print($"Delete OpHandle 0x{handle.ToString("x16")}");
#endif
c_api.TFE_DeleteOp(handle);
SetHandle(IntPtr.Zero);
return true;


+ 4
- 0
src/TensorFlowNET.Core/Eager/SafeTensorHandleHandle.cs View File

@@ -16,6 +16,7 @@

using System;
using Tensorflow.Util;
using static Tensorflow.Binding;

namespace Tensorflow.Eager
{
@@ -32,6 +33,9 @@ namespace Tensorflow.Eager

protected override bool ReleaseHandle()
{
#if TRACK_TENSOR_LIFE
print($"Delete EagerTensorHandle 0x{handle.ToString("x16")}");
#endif
c_api.TFE_DeleteTensorHandle(handle);
SetHandle(IntPtr.Zero);
return true;


+ 1
- 1
src/TensorFlowNET.Core/Tensorflow.Binding.csproj View File

@@ -38,7 +38,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works</PackageReleaseN

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<DefineConstants>TRACE;DEBUG</DefineConstants>
<DefineConstants>TRACE;DEBUG;TRACK_TENSOR_LIFE_1</DefineConstants>
<PlatformTarget>AnyCPU</PlatformTarget>
</PropertyGroup>



+ 10
- 13
src/TensorFlowNET.Core/Tensors/Tensor.Creation.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using NumSharp;
using NumSharp.Backends.Unmanaged;
using System;
using System.Diagnostics.CodeAnalysis;
using System.Linq;
@@ -30,11 +31,6 @@ namespace Tensorflow
[SuppressMessage("ReSharper", "InvokeAsExtensionMethod")]
public partial class Tensor
{
/// <summary>
/// When Tensor was created from an object that is managed by C#'s GC - this will hold reference to prevent it from being collected.
/// </summary>
protected object AllocationReferenceHolder;

/// <summary>
/// The handle that was used to allocate this tensor, dependent on <see cref="AllocationType"/>.
/// </summary>
@@ -545,33 +541,34 @@ namespace Tensorflow
return;
}

_handle = CreateTensorFromNDArray(nd, tensorDType);
CreateTensorFromNDArray(nd, tensorDType);
}

private unsafe IntPtr CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype)
private unsafe void CreateTensorFromNDArray(NDArray nd, TF_DataType? given_dtype)
{
if (nd.typecode == NPTypeCode.String)
throw new NotImplementedException("Support for NDArray of type string not implemented yet");

var arraySlice = nd.Unsafe.Storage.Shape.IsContiguous ? nd.GetData() : nd.CloneData();

var handle = TF_NewTensor(
_handle = TF_NewTensor(
given_dtype ?? nd.dtype.as_dtype(),
dims: nd.shape.Select(i => (long)i).ToArray(),
num_dims: nd.ndim,
data: arraySlice.Address,
len: (ulong)(nd.size * nd.dtypesize));

//if TF decided not to perform copy, hold reference for given NDArray.
if (TF_TensorData(handle).ToPointer() == arraySlice.Address)
// if TF decided not to perform copy, hold reference for given NDArray.
if (TensorDataPointer.ToPointer() == arraySlice.Address)
{
AllocationType = AllocationType.FromPointer;
AllocationReferenceHolder = arraySlice;
AllocationHandle = arraySlice;
#if TRACK_TENSOR_LIFE
print($"New Tensor {Id} {AllocationType} 0x{TensorDataPointer.ToString("x16")}");
#endif
}
else
AllocationType = AllocationType.Tensorflow;

return handle;
}

public Tensor(Operation op, int value_index, TF_DataType dtype)


+ 31
- 17
src/TensorFlowNET.Core/Tensors/Tensor.cs View File

@@ -15,6 +15,7 @@
******************************************************************************/

using NumSharp;
using NumSharp.Backends.Unmanaged;
using System;
using System.Diagnostics.CodeAnalysis;
using System.Globalization;
@@ -253,30 +254,43 @@ namespace Tensorflow
/// <remarks>Equivalent to what you would perform inside <see cref="DisposableObject.Dispose"/></remarks>
protected override void DisposeManagedResources()
{
AllocationReferenceHolder = null;
}

[SuppressMessage("ReSharper", "ConvertIfStatementToSwitchStatement")]
protected override void DisposeUnmanagedResources(IntPtr handle)
{
c_api.TF_DeleteTensor(handle);
if (AllocationHandle == null)
return;

if (AllocationType == AllocationType.GCHandle)
{
((GCHandle)AllocationHandle).Free();
AllocationHandle = null;
AllocationType = AllocationType.None;
}
else if (AllocationType == AllocationType.Marshal)
if (AllocationHandle != null)
{
Marshal.FreeHGlobal((IntPtr)AllocationHandle);
AllocationHandle = null;
AllocationType = AllocationType.None;

#if TRACK_TENSOR_LIFE
print($"Delete AllocationHandle.{AllocationType} 0x{TensorDataPointer.ToString("x16")}");
#endif
if (AllocationType == AllocationType.GCHandle)
{
((GCHandle)AllocationHandle).Free();
AllocationHandle = null;
AllocationType = AllocationType.None;
}
else if (AllocationType == AllocationType.Marshal)
{
Marshal.FreeHGlobal((IntPtr)AllocationHandle);
AllocationHandle = null;
AllocationType = AllocationType.None;
}
else if (AllocationType == AllocationType.FromPointer)
{
AllocationHandle = null;
AllocationType = AllocationType.None;
}
else
throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType}).");
}
else
throw new InvalidOperationException($"Tensor.AllocationHandle is not null ({AllocationHandle}) but AllocationType is not matched to a C# allocation type ({AllocationType}).");

#if TRACK_TENSOR_LIFE
print($"Delete TensorHandle 0x{handle.ToString("x16")}");
#endif
c_api.TF_DeleteTensor(handle);
}

public virtual IntPtr ToPointer()


+ 7
- 1
src/TensorFlowNET.Core/Tensors/Tensors.cs View File

@@ -13,7 +13,7 @@ namespace Tensorflow
/// and Tensor[] from Tensors implicitly.
/// It works for tuple and scalar as well.
/// </summary>
public class Tensors : IEnumerable<Tensor>
public class Tensors : IEnumerable<Tensor>, IDisposable
{
List<Tensor> items = new List<Tensor>();

@@ -90,5 +90,11 @@ namespace Tensorflow
=> items.Count() == 1
? items.First().ToString()
: items.Count() + " Tensors" + ". " + string.Join(", ", items.Select(x => x.name));

public void Dispose()
{
foreach (var item in items)
item.Dispose();
}
}
}

+ 1
- 1
src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs View File

@@ -43,7 +43,7 @@ namespace Tensorflow
public Operation Initializer => initializer_op;
public Operation Op => handle.op;
public Graph Graph => handle.graph;
public string Device => "";
public string Device => handle.Device;

public BaseResourceVariable()
{


+ 0
- 15
src/TensorFlowNET.Core/Variables/ResourceVariable.cs View File

@@ -26,18 +26,6 @@ namespace Tensorflow
/// </summary>
public partial class ResourceVariable : BaseResourceVariable, IVariableV1
{
Tensor _cached_value;
public string Device => handle.Device;
#pragma warning disable CS0108 // Member hides inherited member; missing new keyword
public Graph Graph => handle.graph;
#pragma warning restore CS0108 // Member hides inherited member; missing new keyword
public Operation op => handle.op;
public Tensor is_initialized_op { get; set; }

public ResourceVariable(IntPtr handle, IntPtr tensor) : base(handle, tensor)
{
}

public ResourceVariable(object initial_value = null,
bool trainable = true,
List<string> collections = null,
@@ -150,7 +138,6 @@ namespace Tensorflow
graph_mode: _in_graph_mode);

gen_resource_variable_ops.assign_variable_op(handle, _initial_value);
is_initialized_op = null;
initializer_op = null;
_graph_element = null;
_dtype = _initial_value.dtype.as_base_dtype();
@@ -199,8 +186,6 @@ namespace Tensorflow
{
prepend_name_scope = ops.prepend_name_scope(variable_def.SnapshotName, import_scope: import_scope);
var snapshot = g.as_graph_element(prepend_name_scope) as Tensor;
if (snapshot.op.type != "ReadVariableOp")
_cached_value = snapshot;
while (snapshot.op.type != "ReadVariableOp")
snapshot = snapshot.op.inputs[0];
_graph_element = snapshot;


+ 1
- 1
src/TensorFlowNET.Core/Variables/_UnreadVariable.cs View File

@@ -11,7 +11,7 @@ namespace Tensorflow
public override string Name => _in_graph_mode ? _parent_op.name : "UnreadVariable";

public _UnreadVariable(Tensor handle, TF_DataType dtype, TensorShape shape,
bool in_graph_mode, string unique_id) : base()
bool in_graph_mode, string unique_id)
{
_dtype = dtype;
_shape = shape;


Loading…
Cancel
Save