Browse Source

Add CheckpointReader and corresponding C APIs.

pull/989/head
Yaohui Liu 2 years ago
parent
commit
62309fcd8e
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
2 changed files with 123 additions and 0 deletions
  1. +94
    -0
      src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs
  2. +29
    -0
      src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs

+ 94
- 0
src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs View File

@@ -0,0 +1,94 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;

namespace Tensorflow.Checkpoint
{
internal class CheckpointReader : IDisposable
{
private IntPtr _reader;
public Dictionary<string, TF_DataType> VariableToDataTypeMap { get; set; }
public Dictionary<string, Shape> VariableToShapeMap { get; set; }

public CheckpointReader(string filename)
{
Status status = new Status();
_reader = c_api.TF_NewCheckpointReader(filename, status.Handle);
status.Check(true);
ReadAllShapeAndType();
}

public int HasTensor(string name)
{
return c_api.TF_CheckpointReaderHasTensor(_reader, name);
}

/// <summary>
/// Get the variable name.
/// </summary>
/// <param name="index"></param>
/// <returns></returns>
public string GetVariable(int index)
{
return c_api.TF_CheckpointReaderGetVariable(_reader, index);
}

public int Size()
{
return c_api.TF_CheckpointReaderSize(_reader);
}

public TF_DataType GetVariableDataType(string name)
{
return c_api.TF_CheckpointReaderGetVariableDataType(_reader, name);
}

public Shape GetVariableShape(string name)
{
// TODO(Rinne): Change it to a constant.
int num_dims = GetVariableNumDims(name);
long[] dims = new long[num_dims];
Status status = new Status();
c_api.TF_CheckpointReaderGetVariableShape(_reader, name, dims, num_dims, status.Handle);
status.Check(true);
return new Shape(dims);
}

public int GetVariableNumDims(string name)
{
return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name);
}

public Tensor GetTensor(string name)
{
Status status = new Status();
var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle);
status.Check(true);
var shape = GetVariableShape(name);
var dtype = GetVariableDataType(name);
return new Tensor(tensor, shape, dtype);
}

private void ReadAllShapeAndType()
{
VariableToDataTypeMap = new Dictionary<string, TF_DataType>();
VariableToShapeMap = new Dictionary<string, Shape>();
int size = Size();
for(int i = 0; i < size; i++)
{
var name = GetVariable(i);
var shape = GetVariableShape(name);
var dtype = GetVariableDataType(name);
VariableToDataTypeMap[name] = dtype;
VariableToShapeMap[name] = shape;
}
}

public void Dispose()
{
c_api.TF_DeleteCheckpointReader(_reader);
}
}
}

+ 29
- 0
src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs View File

@@ -0,0 +1,29 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Runtime.InteropServices;

namespace Tensorflow
{
public unsafe partial class c_api
{
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TF_NewCheckpointReader(string filename, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
internal static extern void TF_DeleteCheckpointReader(IntPtr reader);
[DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderHasTensor(IntPtr reader, string name);
[DllImport(TensorFlowLibName)]
internal static extern string TF_CheckpointReaderGetVariable(IntPtr reader, int index);
[DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderSize(IntPtr reader);
[DllImport(TensorFlowLibName)]
internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(IntPtr reader, string name);
[DllImport(TensorFlowLibName)]
internal static extern void TF_CheckpointReaderGetVariableShape(IntPtr reader, string name, long[] dims, int num_dims, SafeStatusHandle status);
[DllImport(TensorFlowLibName)]
internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name);
[DllImport(TensorFlowLibName)]
internal static extern IntPtr TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status);
}
}

Loading…
Cancel
Save