diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index d2fb0fe3..5b3853fc 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -195,32 +195,24 @@ namespace LLama unsafe { - var bigMemory = Marshal.AllocHGlobal((nint)stateSize); - var smallMemory = IntPtr.Zero; + // Allocate a chunk of memory large enough to hold the entire state + var memory = Marshal.AllocHGlobal((nint)stateSize); try { - // Copy the state data into "big memory", discover the actual size required - var actualSize = _ctx.GetState(bigMemory, stateSize); + // Copy the state data into memory, discover the actual size required + var actualSize = _ctx.GetState(memory, stateSize); - // if big memory is nearly completely full (within 1MB) early exit and skip the extra copying - if (actualSize >= stateSize - 1_000_000) - return new State(bigMemory); + // Shrink to size + memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); - // Allocate a smaller buffer which is exactly the right size - smallMemory = Marshal.AllocHGlobal((nint)actualSize); - - // Copy into the smaller buffer and free the large one to save excess memory usage - Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize); - - return new State(smallMemory); + // Wrap memory in a state and return it + memory = IntPtr.Zero; + return new State(memory); } - catch + finally { - if (bigMemory != IntPtr.Zero) - Marshal.FreeHGlobal(bigMemory); - if (smallMemory != IntPtr.Zero) - Marshal.FreeHGlobal(smallMemory); - throw; + if (memory != IntPtr.Zero) + Marshal.FreeHGlobal(memory); } } }