From 3dafbf04e90b54b418e831389d6233319d80dd45 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Fri, 3 Mar 2023 11:22:31 +0800 Subject: [PATCH] Fix the stucking of training when loading model. --- .../Checkpoint/CheckpointReader.cs | 62 +++++++++---------- .../Checkpoint/c_api.checkpoint.cs | 22 +++---- 2 files changed, 39 insertions(+), 45 deletions(-) diff --git a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs index 2a8e2382..0cc8e5fb 100644 --- a/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs +++ b/src/TensorFlowNET.Core/Checkpoint/CheckpointReader.cs @@ -1,29 +1,42 @@ -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Runtime.InteropServices; -using System.Text; -using Tensorflow.Util; +using Tensorflow.Util; namespace Tensorflow.Checkpoint { - public class CheckpointReader : SafeTensorflowHandle + sealed class SafeCheckpointReaderHandle : SafeTensorflowHandle { + public SafeCheckpointReaderHandle(): base() + { + + } + public SafeCheckpointReaderHandle(IntPtr handle): base(handle) + { + + } + + protected override bool ReleaseHandle() + { + c_api.TF_DeleteCheckpointReader(handle); + SetHandle(IntPtr.Zero); + return true; + } + } + public class CheckpointReader + { + private SafeCheckpointReaderHandle _handle; public Dictionary VariableToDataTypeMap { get; set; } public Dictionary VariableToShapeMap { get; set; } public CheckpointReader(string filename) { Status status = new Status(); - handle = c_api.TF_NewCheckpointReader(filename, status.Handle); + _handle = c_api.TF_NewCheckpointReader(filename, status.Handle); status.Check(true); ReadAllShapeAndType(); } public int HasTensor(string name) { - return c_api.TF_CheckpointReaderHasTensor(handle, name); + return c_api.TF_CheckpointReaderHasTensor(_handle, name); } /// @@ -33,45 +46,39 @@ namespace Tensorflow.Checkpoint /// public string GetVariable(int index) { - return c_api.TF_CheckpointReaderGetVariable(handle, index); + return c_api.StringPiece(c_api.TF_CheckpointReaderGetVariable(_handle, index)); } public int Size() { - return c_api.TF_CheckpointReaderSize(handle); + return c_api.TF_CheckpointReaderSize(_handle); } public TF_DataType GetVariableDataType(string name) { - return c_api.TF_CheckpointReaderGetVariableDataType(handle, name); + return c_api.TF_CheckpointReaderGetVariableDataType(_handle, name); } public Shape GetVariableShape(string name) { - // TODO(Rinne): Change it to a constant. int num_dims = GetVariableNumDims(name); long[] dims = new long[num_dims]; Status status = new Status(); - c_api.TF_CheckpointReaderGetVariableShape(handle, name, dims, num_dims, status.Handle); + c_api.TF_CheckpointReaderGetVariableShape(_handle, name, dims, num_dims, status.Handle); status.Check(true); return new Shape(dims); } public int GetVariableNumDims(string name) { - return c_api.TF_CheckpointReaderGetVariableNumDims(handle, name); + return c_api.TF_CheckpointReaderGetVariableNumDims(_handle, name); } public unsafe Tensor GetTensor(string name, TF_DataType dtype = TF_DataType.DtInvalid) { Status status = new Status(); - var tensor = c_api.TF_CheckpointReaderGetTensor(handle, name, status.Handle); + var tensor = c_api.TF_CheckpointReaderGetTensor(_handle, name, status.Handle); status.Check(true); - var shape = GetVariableShape(name); - if(dtype == TF_DataType.DtInvalid) - { - dtype = GetVariableDataType(name); - } return new Tensor(tensor); } @@ -89,16 +96,5 @@ namespace Tensorflow.Checkpoint VariableToShapeMap[name] = shape; } } - - protected override bool ReleaseHandle() - { - c_api.TF_DeleteCheckpointReader(handle); - return true; - } - - public void Dispose() - { - c_api.TF_DeleteCheckpointReader(handle); - } } } diff --git a/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs index 8a6858f6..f956e333 100644 --- a/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/c_api.checkpoint.cs @@ -1,29 +1,27 @@ -using System; -using System.Collections.Generic; -using System.Text; -using System.Runtime.InteropServices; +using System.Runtime.InteropServices; +using Tensorflow.Checkpoint; namespace Tensorflow { public unsafe partial class c_api { [DllImport(TensorFlowLibName)] - internal static extern IntPtr TF_NewCheckpointReader(string filename, SafeStatusHandle status); + internal static extern SafeCheckpointReaderHandle TF_NewCheckpointReader(string filename, SafeStatusHandle status); [DllImport(TensorFlowLibName)] internal static extern void TF_DeleteCheckpointReader(IntPtr reader); [DllImport(TensorFlowLibName)] - internal static extern int TF_CheckpointReaderHasTensor(IntPtr reader, string name); + internal static extern int TF_CheckpointReaderHasTensor(SafeCheckpointReaderHandle reader, string name); [DllImport(TensorFlowLibName)] - internal static extern string TF_CheckpointReaderGetVariable(IntPtr reader, int index); + internal static extern IntPtr TF_CheckpointReaderGetVariable(SafeCheckpointReaderHandle reader, int index); [DllImport(TensorFlowLibName)] - internal static extern int TF_CheckpointReaderSize(IntPtr reader); + internal static extern int TF_CheckpointReaderSize(SafeCheckpointReaderHandle reader); [DllImport(TensorFlowLibName)] - internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(IntPtr reader, string name); + internal static extern TF_DataType TF_CheckpointReaderGetVariableDataType(SafeCheckpointReaderHandle reader, string name); [DllImport(TensorFlowLibName)] - internal static extern void TF_CheckpointReaderGetVariableShape(IntPtr reader, string name, long[] dims, int num_dims, SafeStatusHandle status); + internal static extern void TF_CheckpointReaderGetVariableShape(SafeCheckpointReaderHandle reader, string name, long[] dims, int num_dims, SafeStatusHandle status); [DllImport(TensorFlowLibName)] - internal static extern int TF_CheckpointReaderGetVariableNumDims(IntPtr reader, string name); + internal static extern int TF_CheckpointReaderGetVariableNumDims(SafeCheckpointReaderHandle reader, string name); [DllImport(TensorFlowLibName)] - internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(IntPtr reader, string name, SafeStatusHandle status); + internal static extern SafeTensorHandle TF_CheckpointReaderGetTensor(SafeCheckpointReaderHandle reader, string name, SafeStatusHandle status); } }