|
|
|
@@ -1,29 +1,44 @@ |
|
|
|
using System; |
|
|
|
using System.Collections.Generic; |
|
|
|
using System.IO; |
|
|
|
using System.Linq; |
|
|
|
using System.Runtime.InteropServices; |
|
|
|
using System.Text; |
|
|
|
using Tensorflow.Util; |
|
|
|
using Tensorflow.Util; |
|
|
|
|
|
|
|
namespace Tensorflow.Checkpoint |
|
|
|
{ |
|
|
|
public class CheckpointReader : SafeTensorflowHandle |
|
|
|
sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle |
|
|
|
{ |
|
|
|
public SafeCheckpointReaderHandle(): base() |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
|
public SafeCheckpointReaderHandle(IntPtr handle): base(handle) |
|
|
|
{ |
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
protected override bool ReleaseHandle() |
|
|
|
{ |
|
|
|
//if (handle != IntPtr.Zero) |
|
|
|
//{ |
|
|
|
// c_api.TF_DeleteCheckpointReader(this); |
|
|
|
//} |
|
|
|
return true; |
|
|
|
} |
|
|
|
} |
|
|
|
public class CheckpointReader |
|
|
|
{ |
|
|
|
private SafeCheckpointReaderHandle _handle; |
|
|
|
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; } |
|
|
|
public Dictionary<string, Shape> VariableToShapeMap { get; set; } |
|
|
|
|
|
|
|
public CheckpointReader(string filename) |
|
|
|
{ |
|
|
|
Status status = new Status(); |
|
|
|
handle = c_api.TF_NewCheckpointReader(filename, status.Handle); |
|
|
|
_handle = c_api.TF_NewCheckpointReader(filename, status.Handle); |
|
|
|
status.Check(true); |
|
|
|
ReadAllShapeAndType(); |
|
|
|
} |
|
|
|
|
|
|
|
public int HasTensor(string name) |
|
|
|
{ |
|
|
|
return c_api.TF_CheckpointReaderHasTensor(handle, name); |
|
|
|
return c_api.TF_CheckpointReaderHasTensor(_handle, name); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
@@ -33,45 +48,39 @@ namespace Tensorflow.Checkpoint |
|
|
|
/// <returns></returns> |
|
|
|
public string GetVariable(int index) |
|
|
|
{ |
|
|
|
return c_api.TF_CheckpointReaderGetVariable(handle, index); |
|
|
|
return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); |
|
|
|
} |
|
|
|
|
|
|
|
public int Size() |
|
|
|
{ |
|
|
|
return c_api.TF_CheckpointReaderSize(handle); |
|
|
|
return c_api.TF_CheckpointReaderSize(_handle); |
|
|
|
} |
|
|
|
|
|
|
|
public TF_DataType GetVariableDataType(string name) |
|
|
|
{ |
|
|
|
return c_api.TF_CheckpointReaderGetVariableDataType(handle, name); |
|
|
|
return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); |
|
|
|
} |
|
|
|
|
|
|
|
public Shape GetVariableShape(string name) |
|
|
|
{ |
|
|
|
// TODO(Rinne): Change it to a constant. |
|
|
|
int num_dims = GetVariableNumDims(name); |
|
|
|
long[] dims = new long[num_dims]; |
|
|
|
Status status = new Status(); |
|
|
|
c_api.TF_CheckpointReaderGetVariableShape(handle, name, dims, num_dims, status.Handle); |
|
|
|
c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); |
|
|
|
status.Check(true); |
|
|
|
return new Shape(dims); |
|
|
|
} |
|
|
|
|
|
|
|
public int GetVariableNumDims(string name) |
|
|
|
{ |
|
|
|
return c_api.TF_CheckpointReaderGetVariableNumDims(handle, name); |
|
|
|
return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); |
|
|
|
} |
|
|
|
|
|
|
|
public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) |
|
|
|
{ |
|
|
|
Status status = new Status(); |
|
|
|
var tensor = c_api.TF_CheckpointReaderGetTensor(handle, name, status.Handle); |
|
|
|
var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); |
|
|
|
status.Check(true); |
|
|
|
var shape = GetVariableShape(name); |
|
|
|
if(dtype == TF_DataType.DtInvalid) |
|
|
|
{ |
|
|
|
dtype = GetVariableDataType(name); |
|
|
|
} |
|
|
|
return new Tensor(tensor); |
|
|
|
} |
|
|
|
|
|
|
|
@@ -89,16 +98,5 @@ namespace Tensorflow.Checkpoint |
|
|
|
VariableToShapeMap[name] = shape; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
protected override bool ReleaseHandle() |
|
|
|
{ |
|
|
|
c_api.TF_DeleteCheckpointReader(handle); |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
public void Dispose() |
|
|
|
{ |
|
|
|
c_api.TF_DeleteCheckpointReader(handle); |
|
|
|
} |
|
|
|
} |
|
|
|
} |