From 62309fcd8e79210f5bb3eedc1e815b018335019d Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Mon, 27 Feb 2023 21:35:38 +0800 Subject: [PATCH] Add CheckpointReader and corresponding C APIs. --- .../Checkpoint/CheckpointReader.cs | 94 +++++++++++++++++++ .../Checkpoint/c_api.checkpoint.cs | 29 ++++++ 2 files changed, 123 insertions(+) create mode 100644 src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs create mode 100644 src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs new file mode 100644 index 00000000..49976280 --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs @@ -0,0 +1,94 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +namespace Tensorflow.Checkpoint +{ + internal class CheckpointReader : IDisposable + { + private IntPtr _reader; + public Dictionary VariableToDataTypeMap { get; set; } + public Dictionary VariableToShapeMap { get; set; } + + public CheckpointReader(string filename) + { + Status status = new Status(); + _reader = c_api.TF_NewCheckpointReader(filename, status.Handle); + status.Check(true); + ReadAllShapeAndType(); + } + + public int HasTensor(string name) + { + return c_api.TF_CheckpointReaderHasTensor(_reader, name); + } + + /// + /// Get the variable name. + /// + /// + /// + public string GetVariable(int index) + { + return c_api.TF_CheckpointReaderGetVariable(_reader, index); + } + + public int Size() + { + return c_api.TF_CheckpointReaderSize(_reader); + } + + public TF_DataType GetVariableDataType(string name) + { + return c_api.TF_CheckpointReaderGetVariableDataType(_reader, name); + } + + public Shape GetVariableShape(string name) + { + // TODO(Rinne): Change it to a constant. + int num_dims = GetVariableNumDims(name); + long[] dims = new long[num_dims]; + Status status = new Status(); + c_api.TF_CheckpointReaderGetVariableShape(_reader, name, dims, num_dims, status.Handle); + status.Check(true); + return new Shape(dims); + } + + public int GetVariableNumDims(string name) + { + return c_api.TF_CheckpointReaderGetVariableNumDims(_reader, name); + } + + public Tensor GetTensor(string name) + { + Status status = new Status(); + var tensor = c_api.TF_CheckpointReaderGetTensor(_reader, name, status.Handle); + status.Check(true); + var shape = GetVariableShape(name); + var dtype = GetVariableDataType(name); + return new Tensor(tensor, shape, dtype); + } + + private void ReadAllShapeAndType() + { + VariableToDataTypeMap = new Dictionary(); + VariableToShapeMap = new Dictionary(); + int size = Size(); + for(int i = 0; i < size; i++) + { + var name = GetVariable(i); + var shape = GetVariableShape(name); + var dtype = GetVariableDataType(name); + VariableToDataTypeMap[name] = dtype; + VariableToShapeMap[name] = shape; + } + } + + public void Dispose() + { + c_api.TF_DeleteCheckpointReader(_reader); + } + } +} diff --git a/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs new file mode 100644 index 00000000..2132cd1d --- /dev/null +++ b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Runtime.InteropServices; + +namespace Tensorflow +{ + public unsafe partial class c_api + { + [DllImport(TensorFlowLibName)] + internal static extern IntPtr TF_NewCheckpointReader(string filename, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + internal static extern void TF_DeleteCheckpointReader(IntPtr reader); + [DllImport(TensorFlowLibName)] + internal static extern int TF_CheckpointReaderHasTensor(IntPtr reader, string name); + [DllImport(TensorFlowLibName)] + internal static extern string TF_CheckpointReaderGetVariable(IntPtr reader, int index); + [DllImport(TensorFlowLibName)] + internal static extern int TF_CheckpointReaderSize(IntPtr reader); + [DllImport(TensorFlowLibName)] + internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(IntPtr reader, string name); + [DllImport(TensorFlowLibName)] + internal static extern void TF_CheckpointReaderGetVariableShape(IntPtr reader, string name, long[] dims, int num_dims, SafeStatusHandle status); + [DllImport(TensorFlowLibName)] + internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name); + [DllImport(TensorFlowLibName)] + internal static extern IntPtr TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status); + } +}