Browse Source

Fix the stucking of training when loading model.

pull/989/head
Yaohui Liu 2 years ago
parent
commit
3dafbf04e9
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
2 changed files with 39 additions and 45 deletions
  1. +29
    -33
      src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
  2. +10
    -12
      src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs

+ 29
- 33
src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs View File

@@ -1,29 +1,42 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
using Tensorflow.Util;
using Tensorflow.Util;


namespace Tensorflow.Checkpoint namespace Tensorflow.Checkpoint
{ {
public class CheckpointReader : SafeTensorflowHandle
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle
{ {
public SafeCheckpointReaderHandle(): base()
{
}
public SafeCheckpointReaderHandle(IntPtr handle): base(handle)
{

}

protected override bool ReleaseHandle()
{
c_api.TF_DeleteCheckpointReader(handle);
SetHandle(IntPtr.Zero);
return true;
}
}
public class CheckpointReader
{
private SafeCheckpointReaderHandle _handle;
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
public Dictionary<string, Shape> VariableToShapeMap { get; set; } public Dictionary<string, Shape> VariableToShapeMap { get; set; }


public CheckpointReader(string filename) public CheckpointReader(string filename)
{ {
Status status = new Status(); Status status = new Status();
handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle);
status.Check(true); status.Check(true);
ReadAllShapeAndType(); ReadAllShapeAndType();
} }


public int HasTensor(string name) public int HasTensor(string name)
{ {
return c_api.TF_CheckpointReaderHasTensor(handle, name);
return c_api.TF_CheckpointReaderHasTensor(_handle, name);
} }


/// <summary> /// <summary>
@@ -33,45 +46,39 @@ namespace Tensorflow.Checkpoint
/// <returns></returns> /// <returns></returns>
public string GetVariable(int index) public string GetVariable(int index)
{ {
return c_api.TF_CheckpointReaderGetVariable(handle, index);
return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index));
} }


public int Size() public int Size()
{ {
return c_api.TF_CheckpointReaderSize(handle);
return c_api.TF_CheckpointReaderSize(_handle);
} }


public TF_DataType GetVariableDataType(string name) public TF_DataType GetVariableDataType(string name)
{ {
return c_api.TF_CheckpointReaderGetVariableDataType(handle, name);
return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name);
} }


public Shape GetVariableShape(string name) public Shape GetVariableShape(string name)
{ {
// TODO(Rinne): Change it to a constant.
int num_dims = GetVariableNumDims(name); int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims]; long[] dims = new long[num_dims];
Status status = new Status(); Status status = new Status();
c_api.TF_CheckpointReaderGetVariableShape(handle, name, dims, num_dims, status.Handle);
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle);
status.Check(true); status.Check(true);
return new Shape(dims); return new Shape(dims);
} }


public int GetVariableNumDims(string name) public int GetVariableNumDims(string name)
{ {
return c_api.TF_CheckpointReaderGetVariableNumDims(handle, name);
return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name);
} }


public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid)
{ {
Status status = new Status(); Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(handle, name, status.Handle);
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle);
status.Check(true); status.Check(true);
var shape = GetVariableShape(name);
if(dtype == TF_DataType.DtInvalid)
{
dtype = GetVariableDataType(name);
}
return new Tensor(tensor); return new Tensor(tensor);
} }


@@ -89,16 +96,5 @@ namespace Tensorflow.Checkpoint
VariableToShapeMap[name] = shape; VariableToShapeMap[name] = shape;
} }
} }

protected override bool ReleaseHandle()
{
c_api.TF_DeleteCheckpointReader(handle);
return true;
}

public void Dispose()
{
c_api.TF_DeleteCheckpointReader(handle);
}
} }
} }

+ 10
- 12
src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs View File

@@ -1,29 +1,27 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Runtime.InteropServices;
using System.Runtime.InteropServices;
using Tensorflow.Checkpoint;


namespace Tensorflow namespace Tensorflow
{ {
public unsafe partial class c_api public unsafe partial class c_api
{ {
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern IntPtr TF_NewCheckpointReader(string filename, SafeStatusHandle status);
internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status);
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern void TF_DeleteCheckpointReader(IntPtr reader); internal static extern void TF_DeleteCheckpointReader(IntPtr reader);
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderHasTensor(IntPtr reader, string name);
internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name);
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern string TF_CheckpointReaderGetVariable(IntPtr reader, int index);
internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index);
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderSize(IntPtr reader);
internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader);
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(IntPtr reader, string name);
internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name);
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern void TF_CheckpointReaderGetVariableShape(IntPtr reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name);
internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name);
[DllImport(TensorFlowLibName)] [DllImport(TensorFlowLibName)]
internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status);
internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status);
} }
} }

Loading…
Cancel
Save