diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index 6fc206ed..5023f815 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -26,5 +26,32 @@ namespace LLama.Unittest Assert.Equal(2048, _model.ContextSize); Assert.Equal(4096, _model.EmbeddingSize); } + + [Fact] + public void CloneContext() + { + var original = _model.CreateContext(_params); + + // Evaluate something (doesn't matter what, as long as it begins with token 1) + original.Eval(new[] { 1, 42, 321 }, 0); + + // Clone current state + var clone = original.Clone(); + + // Now evaluate something more + var reply1a = original.Eval(new[] { 4, 5, 6 }, 3); + var reply2a = original.Eval(new[] { 7, 8, 9 }, 6); + + // Assert that the context replied differently each time + Assert.NotEqual(reply1a, reply2a); + + // Give the same prompts to the cloned state + var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3); + var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6); + + // Assert that the cloned context replied in the same way as originally + Assert.Equal(reply1a, reply1b); + Assert.Equal(reply2a, reply2b); + } } } \ No newline at end of file diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 9501d570..159c641c 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -19,7 +19,7 @@ namespace LLama /// /// A llama_context, which holds all the context required to interact with a model /// - public class LLamaContext + public sealed class LLamaContext : IDisposable { private readonly ILLamaLogger? _logger; @@ -111,15 +111,8 @@ namespace LLama public LLamaContext Clone() { using var pin = Params.ToLlamaContextParams(out var lparams); - - // Create a blank new context for the model - var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params); - - // Copy across the state - using var state = GetState(); - ctx.LoadState(state); - - return ctx; + var clone = _ctx.Clone(lparams); + return new LLamaContext(clone, Params); } /// @@ -197,7 +190,7 @@ namespace LLama /// public State GetState() { - var stateSize = NativeApi.llama_get_state_size(_ctx); + var stateSize = _ctx.GetStateSize(); unsafe { @@ -206,15 +199,17 @@ namespace LLama try { // Copy the state data into "big memory", discover the actual size required - var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory); + var actualSize = _ctx.GetState(bigMemory, stateSize); + + // if big memory is nearly completely full (within 1MB) early exit and skip the extra copying + if (actualSize >= stateSize - 1_000_000) + return new State(bigMemory); - // Allocate a smaller buffer + // Allocate a smaller buffer which is exactly the right size smallMemory = Marshal.AllocHGlobal((nint)actualSize); // Copy into the smaller buffer and free the large one to save excess memory usage Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize); - Marshal.FreeHGlobal(bigMemory); - bigMemory = IntPtr.Zero; return new State(smallMemory); } @@ -274,7 +269,7 @@ namespace LLama { unsafe { - NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer()); + _ctx.SetState((byte*)state.DangerousGetHandle().ToPointer()); } } @@ -498,10 +493,8 @@ namespace LLama } /// - public virtual void Dispose() + public void Dispose() { - GC.SuppressFinalize(this); - _ctx.Dispose(); } diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 57c305b2..5980d17c 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -8,7 +8,7 @@ namespace LLama /// /// The embedder for LLama, which supports getting embeddings from text. /// - public class LLamaEmbedder + public sealed class LLamaEmbedder : IDisposable { private readonly LLamaContext _ctx; diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 2e499196..86c1c71c 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -1,5 +1,6 @@ using System; using System.Buffers; +using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; @@ -89,6 +90,35 @@ namespace LLama.Native return new(ctx_ptr, model); } + + /// + /// Create a new llama context with a clone of the current llama context state + /// + /// + /// + public SafeLLamaContextHandle Clone(LLamaContextParams lparams) + { + // Allocate space to read the state of the current context + var stateSize = GetStateSize(); + var stateMemory = Marshal.AllocHGlobal((nint)stateSize); + try + { + // Copy state from this context into memory + GetState(stateMemory, stateSize); + + // Create a new context + var newCtx = Create(ModelHandle, lparams); + + // Copy state into new context + newCtx.SetState(stateMemory); + + return newCtx; + } + finally + { + Marshal.FreeHGlobal(stateMemory); + } + } #endregion /// @@ -188,5 +218,69 @@ namespace LLama.Native } } } + + #region state + /// + /// Get the size of the state, when saved as bytes + /// + public ulong GetStateSize() + { + return NativeApi.llama_get_state_size(this); + } + + /// + /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer. + /// + /// Destination to write to + /// Number of bytes available to write to in dest (check required size with `GetStateSize()`) + /// The number of bytes written to dest + /// Thrown if dest is too small + public unsafe ulong GetState(byte* dest, ulong size) + { + return GetState(new IntPtr(dest), size); + } + + /// + /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer. + /// + /// Destination to write to + /// Number of bytes available to write to in dest (check required size with `GetStateSize()`) + /// The number of bytes written to dest + /// Thrown if dest is too small + public ulong GetState(IntPtr dest, ulong size) + { + var required = GetStateSize(); + if (size < required) + throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}"); + + unsafe + { + return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer()); + } + } + + /// + /// Set the raw state of this context + /// + /// The pointer to read the state from + /// Number of bytes read from the src pointer + public unsafe ulong SetState(byte* src) + { + return SetState(new IntPtr(src)); + } + + /// + /// Set the raw state of this context + /// + /// The pointer to read the state from + /// Number of bytes read from the src pointer + public ulong SetState(IntPtr src) + { + unsafe + { + return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer()); + } + } + #endregion } }