diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index 6fc206ed..5023f815 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -26,5 +26,32 @@ namespace LLama.Unittest
Assert.Equal(2048, _model.ContextSize);
Assert.Equal(4096, _model.EmbeddingSize);
}
+
+ /// <summary>
+ /// Checks that a cloned context, given the same follow-up tokens as the context
+ /// it was cloned from, produces the same replies as the original did.
+ /// </summary>
+ [Fact]
+ public void CloneContext()
+ {
+     // The context is disposable; release it deterministically instead of leaking it for the rest of the test run
+     using var original = _model.CreateContext(_params);
+
+     // Evaluate something (doesn't matter what, as long as it begins with token 1)
+     original.Eval(new[] { 1, 42, 321 }, 0);
+
+     // Clone current state (the clone is a separate context, so it needs its own disposal)
+     using var clone = original.Clone();
+
+     // Now evaluate something more
+     var reply1a = original.Eval(new[] { 4, 5, 6 }, 3);
+     var reply2a = original.Eval(new[] { 7, 8, 9 }, 6);
+
+     // Assert that the context replied differently each time
+     Assert.NotEqual(reply1a, reply2a);
+
+     // Give the same prompts to the cloned state
+     var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3);
+     var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6);
+
+     // Assert that the cloned context replied in the same way as originally
+     Assert.Equal(reply1a, reply1b);
+     Assert.Equal(reply2a, reply2b);
+ }
}
}
\ No newline at end of file
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 9501d570..159c641c 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -19,7 +19,7 @@ namespace LLama
///
/// A llama_context, which holds all the context required to interact with a model
///
- public class LLamaContext
+ public sealed class LLamaContext
: IDisposable
{
private readonly ILLamaLogger? _logger;
@@ -111,15 +111,8 @@ namespace LLama
public LLamaContext Clone()
{
using var pin = Params.ToLlamaContextParams(out var lparams);
-
- // Create a blank new context for the model
- var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params);
-
- // Copy across the state
- using var state = GetState();
- ctx.LoadState(state);
-
- return ctx;
+ var clone = _ctx.Clone(lparams);
+ return new LLamaContext(clone, Params);
}
///
@@ -197,7 +190,7 @@ namespace LLama
///
public State GetState()
{
- var stateSize = NativeApi.llama_get_state_size(_ctx);
+ var stateSize = _ctx.GetStateSize();
unsafe
{
@@ -206,15 +199,17 @@ namespace LLama
try
{
// Copy the state data into "big memory", discover the actual size required
- var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory);
+ var actualSize = _ctx.GetState(bigMemory, stateSize);
+
+ // if big memory is nearly completely full (within 1MB) early exit and skip the extra copying
+ if (actualSize >= stateSize - 1_000_000)
+ return new State(bigMemory);
- // Allocate a smaller buffer
+ // Allocate a smaller buffer which is exactly the right size
smallMemory = Marshal.AllocHGlobal((nint)actualSize);
// Copy into the smaller buffer and free the large one to save excess memory usage
Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize);
- Marshal.FreeHGlobal(bigMemory);
- bigMemory = IntPtr.Zero;
return new State(smallMemory);
}
@@ -274,7 +269,7 @@ namespace LLama
{
unsafe
{
- NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer());
+ _ctx.SetState((byte*)state.DangerousGetHandle().ToPointer());
}
}
@@ -498,10 +493,8 @@ namespace LLama
}
///
- public virtual void Dispose()
+ public void Dispose()
{
- GC.SuppressFinalize(this);
-
_ctx.Dispose();
}
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 57c305b2..5980d17c 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -8,7 +8,7 @@ namespace LLama
///
/// The embedder for LLama, which supports getting embeddings from text.
///
- public class LLamaEmbedder
+ public sealed class LLamaEmbedder
: IDisposable
{
private readonly LLamaContext _ctx;
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 2e499196..86c1c71c 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,5 +1,6 @@
using System;
using System.Buffers;
+using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@@ -89,6 +90,35 @@ namespace LLama.Native
return new(ctx_ptr, model);
}
+
+ /// <summary>
+ /// Create a new llama context with a clone of the current llama context state
+ /// </summary>
+ /// <param name="lparams">Parameters handed to <c>Create</c> when building the new context</param>
+ /// <returns>A new context handle, created for the same model, into which the state of this context has been loaded</returns>
+ public SafeLLamaContextHandle Clone(LLamaContextParams lparams)
+ {
+     // Allocate space to read the state of the current context
+     var stateSize = GetStateSize();
+     var stateMemory = Marshal.AllocHGlobal((nint)stateSize);
+     try
+     {
+         // Copy state from this context into memory
+         GetState(stateMemory, stateSize);
+
+         // Create a new context
+         var newCtx = Create(ModelHandle, lparams);
+         try
+         {
+             // Copy state into new context
+             newCtx.SetState(stateMemory);
+         }
+         catch
+         {
+             // Loading the state failed: release the freshly created context now
+             // rather than leaving it alive until it is eventually finalized
+             newCtx.Dispose();
+             throw;
+         }
+
+         return newCtx;
+     }
+     finally
+     {
+         // The temporary buffer is only needed for the transfer; free it on success and failure alike
+         Marshal.FreeHGlobal(stateMemory);
+     }
+ }
#endregion
///
@@ -188,5 +218,69 @@ namespace LLama.Native
}
}
}
+
+ #region state
+ /// <summary>
+ /// Get the size of the state, when saved as bytes
+ /// </summary>
+ /// <returns>The value reported by <c>NativeApi.llama_get_state_size</c> for this context</returns>
+ public ulong GetStateSize() => NativeApi.llama_get_state_size(this);
+
+ /// <summary>
+ /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
+ /// </summary>
+ /// <param name="dest">Destination to write to</param>
+ /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
+ /// <returns>The number of bytes written to dest</returns>
+ /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
+ public unsafe ulong GetState(byte* dest, ulong size)
+ {
+     // Wrap the raw pointer and hand off to the IntPtr overload, which validates the size
+     var destination = new IntPtr(dest);
+     return GetState(destination, size);
+ }
+
+ /// <summary>
+ /// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
+ /// </summary>
+ /// <param name="dest">Destination to write to</param>
+ /// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
+ /// <returns>The number of bytes written to dest</returns>
+ /// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
+ public ulong GetState(IntPtr dest, ulong size)
+ {
+     // Refuse to start copying unless the caller's buffer can hold the size reported by GetStateSize()
+     var required = GetStateSize();
+     if (required > size)
+     {
+         throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");
+     }
+
+     unsafe
+     {
+         // The native call returns the number of bytes it actually wrote
+         var target = (byte*)dest.ToPointer();
+         return NativeApi.llama_copy_state_data(this, target);
+     }
+ }
+
+ /// <summary>
+ /// Set the raw state of this context
+ /// </summary>
+ /// <param name="src">The pointer to read the state from</param>
+ /// <returns>Number of bytes read from the src pointer</returns>
+ public unsafe ulong SetState(byte* src)
+ {
+     // Wrap the raw pointer and hand off to the IntPtr overload
+     var source = new IntPtr(src);
+     return SetState(source);
+ }
+
+ /// <summary>
+ /// Set the raw state of this context
+ /// </summary>
+ /// <param name="src">The pointer to read the state from</param>
+ /// <returns>Number of bytes read from the src pointer</returns>
+ public ulong SetState(IntPtr src)
+ {
+     unsafe
+     {
+         // The native API takes a byte pointer, so unwrap the IntPtr first
+         var source = (byte*)src.ToPointer();
+         return NativeApi.llama_set_state_data(this, source);
+     }
+ }
+ #endregion
}
}