From 832bf7dbe08953a322e1b08c2a7b650ab50a8c0e Mon Sep 17 00:00:00 2001
From: Martin Evans
Date: Sat, 9 Sep 2023 01:30:35 +0100
Subject: [PATCH] Simplified implementation of `GetState` and fixed a memory
 leak (`bigMemory` was never freed)

---
Review note: the added lines originally cleared `memory` to IntPtr.Zero
*before* passing it to `new State(...)`, so the returned state wrapped a null
pointer and the real allocation was never released. The hunk below wraps the
pointer first and clears the local afterwards. The hunk header and diffstat
were adjusted accordingly; the post-image blob hash on the `index` line was
not recomputed.

 LLama/LLamaContext.cs | 34 ++++++++++++++--------------------
 1 file changed, 14 insertions(+), 20 deletions(-)

diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 1a1845b0..66314bfe 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -195,32 +195,26 @@ namespace LLama
 
             unsafe
             {
-                var bigMemory = Marshal.AllocHGlobal((nint)stateSize);
-                var smallMemory = IntPtr.Zero;
+                // Allocate a chunk of memory large enough to hold the entire state
+                var memory = Marshal.AllocHGlobal((nint)stateSize);
                 try
                 {
-                    // Copy the state data into "big memory", discover the actual size required
-                    var actualSize = _ctx.GetState(bigMemory, stateSize);
+                    // Copy the state data into memory, discover the actual size required
+                    var actualSize = _ctx.GetState(memory, stateSize);
 
-                    // if big memory is nearly completely full (within 1MB) early exit and skip the extra copying
-                    if (actualSize >= stateSize - 1_000_000)
-                        return new State(bigMemory);
+                    // Shrink to size
+                    memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize);
 
-                    // Allocate a smaller buffer which is exactly the right size
-                    smallMemory = Marshal.AllocHGlobal((nint)actualSize);
-
-                    // Copy into the smaller buffer and free the large one to save excess memory usage
-                    Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize);
-
-                    return new State(smallMemory);
+                    // Wrap the pointer in a state *before* clearing the local. Once wrapped, the state is
+                    // expected to own the allocation, so the finally block below must not free it.
+                    var state = new State(memory);
+                    memory = IntPtr.Zero;
+                    return state;
                 }
-                catch
+                finally
                 {
-                    if (bigMemory != IntPtr.Zero)
-                        Marshal.FreeHGlobal(bigMemory);
-                    if (smallMemory != IntPtr.Zero)
-                        Marshal.FreeHGlobal(smallMemory);
-                    throw;
+                    if (memory != IntPtr.Zero)
+                        Marshal.FreeHGlobal(memory);
                 }
             }
         }