Improved Cloning (tags/v0.5.1)
| @@ -26,5 +26,32 @@ namespace LLama.Unittest | |||
| Assert.Equal(2048, _model.ContextSize); | |||
| Assert.Equal(4096, _model.EmbeddingSize); | |||
| } | |||
/// <summary>
/// Checks that a context produced by <c>Clone()</c> gives the same replies as the
/// context it was cloned from when both are fed the same follow-up tokens.
/// </summary>
[Fact]
public void CloneContext()
{
    // LLamaContext is IDisposable, so release the original when the test finishes
    using var original = _model.CreateContext(_params);

    // Evaluate something (doesn't matter what, as long as it begins with token 1)
    // NOTE(review): the second argument looks like the number of tokens already evaluated - confirm
    original.Eval(new[] { 1, 42, 321 }, 0);

    // Clone current state; the clone is a separate context and needs its own disposal
    using var clone = original.Clone();

    // Now evaluate something more
    var reply1a = original.Eval(new[] { 4, 5, 6 }, 3);
    var reply2a = original.Eval(new[] { 7, 8, 9 }, 6);

    // Assert that the context replied differently each time
    Assert.NotEqual(reply1a, reply2a);

    // Give the same prompts to the cloned state
    var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3);
    var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6);

    // Assert that the cloned context replied in the same way as originally
    Assert.Equal(reply1a, reply1b);
    Assert.Equal(reply2a, reply2b);
}
| } | |||
| } | |||
| @@ -19,7 +19,7 @@ namespace LLama | |||
| /// <summary> | |||
| /// A llama_context, which holds all the context required to interact with a model | |||
| /// </summary> | |||
| public class LLamaContext | |||
| public sealed class LLamaContext | |||
| : IDisposable | |||
| { | |||
| private readonly ILLamaLogger? _logger; | |||
| @@ -111,15 +111,8 @@ namespace LLama | |||
public LLamaContext Clone()
{
    // Convert the managed parameters into their native form. The disposable returned
    // alongside them is held until this method exits (presumably it pins memory that
    // lparams refers to - confirm against ToLlamaContextParams).
    using var pin = Params.ToLlamaContextParams(out var lparams);

    // The native handle creates a fresh context and copies the current state into it
    var clone = _ctx.Clone(lparams);

    // Wrap the cloned handle, reusing this context's parameters
    return new LLamaContext(clone, Params);
}
| /// <summary> | |||
| @@ -197,7 +190,7 @@ namespace LLama | |||
| /// <returns></returns> | |||
| public State GetState() | |||
| { | |||
| var stateSize = NativeApi.llama_get_state_size(_ctx); | |||
| var stateSize = _ctx.GetStateSize(); | |||
| unsafe | |||
| { | |||
| @@ -206,15 +199,17 @@ namespace LLama | |||
| try | |||
| { | |||
| // Copy the state data into "big memory", discover the actual size required | |||
| var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory); | |||
| var actualSize = _ctx.GetState(bigMemory, stateSize); | |||
| // if big memory is nearly completely full (within 1MB) early exit and skip the extra copying | |||
| if (actualSize >= stateSize - 1_000_000) | |||
| return new State(bigMemory); | |||
| // Allocate a smaller buffer | |||
| // Allocate a smaller buffer which is exactly the right size | |||
| smallMemory = Marshal.AllocHGlobal((nint)actualSize); | |||
| // Copy into the smaller buffer and free the large one to save excess memory usage | |||
| Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize); | |||
| Marshal.FreeHGlobal(bigMemory); | |||
| bigMemory = IntPtr.Zero; | |||
| return new State(smallMemory); | |||
| } | |||
| @@ -274,7 +269,7 @@ namespace LLama | |||
| { | |||
| unsafe | |||
| { | |||
| NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer()); | |||
| _ctx.SetState((byte*)state.DangerousGetHandle().ToPointer()); | |||
| } | |||
| } | |||
| @@ -498,10 +493,8 @@ namespace LLama | |||
| } | |||
/// <inheritdoc />
public void Dispose()
{
    // Non-virtual: the class is declared sealed, so there are no subclasses to override this
    GC.SuppressFinalize(this);

    // Release the underlying native context handle
    _ctx.Dispose();
}
| @@ -8,7 +8,7 @@ namespace LLama | |||
| /// <summary> | |||
| /// The embedder for LLama, which supports getting embeddings from text. | |||
| /// </summary> | |||
| public class LLamaEmbedder | |||
| public sealed class LLamaEmbedder | |||
| : IDisposable | |||
| { | |||
| private readonly LLamaContext _ctx; | |||
| @@ -1,5 +1,6 @@ | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using LLama.Exceptions; | |||
| @@ -89,6 +90,35 @@ namespace LLama.Native | |||
| return new(ctx_ptr, model); | |||
| } | |||
/// <summary>
/// Create a new llama context with a clone of the current llama context state
/// </summary>
/// <param name="lparams">Parameters used to create the new context</param>
/// <returns>A new context handle, created from the same model, with this context's state loaded into it</returns>
/// <exception cref="OverflowException">Thrown if the state size cannot be represented as a native-sized integer</exception>
public SafeLLamaContextHandle Clone(LLamaContextParams lparams)
{
    // Allocate space to read the state of the current context.
    // The cast is checked so an oversized state cannot silently wrap into a too-small allocation.
    var stateSize = GetStateSize();
    var stateMemory = Marshal.AllocHGlobal(checked((nint)stateSize));
    try
    {
        // Copy state from this context into memory
        GetState(stateMemory, stateSize);

        // Create a new context
        var newCtx = Create(ModelHandle, lparams);
        try
        {
            // Copy state into new context
            newCtx.SetState(stateMemory);
            return newCtx;
        }
        catch
        {
            // Loading the state failed: release the half-initialised context before rethrowing
            newCtx.Dispose();
            throw;
        }
    }
    finally
    {
        // The temporary buffer is only needed while the state is in transit
        Marshal.FreeHGlobal(stateMemory);
    }
}
| #endregion | |||
| /// <summary> | |||
| @@ -188,5 +218,69 @@ namespace LLama.Native | |||
| } | |||
| } | |||
| } | |||
| #region state | |||
/// <summary>
/// Query how many bytes the state of this context occupies when saved as bytes
/// </summary>
/// <returns>The size reported by <c>llama_get_state_size</c> for this context</returns>
public ulong GetStateSize() => NativeApi.llama_get_state_size(this);
/// <summary>
/// Write the raw state of this context, encoded as bytes, into the memory at `dest`.
/// Forwards to the <see cref="IntPtr"/> overload.
/// </summary>
/// <param name="dest">Destination to write to</param>
/// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
/// <returns>The number of bytes written to dest</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
public unsafe ulong GetState(byte* dest, ulong size) => GetState(new IntPtr(dest), size);
/// <summary>
/// Write the raw state of this context, encoded as bytes, into the memory at `dest`.
/// </summary>
/// <param name="dest">Destination to write to</param>
/// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
/// <returns>The number of bytes written to dest</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
public ulong GetState(IntPtr dest, ulong size)
{
    // Refuse to copy anything unless the caller's buffer can hold the full reported state size
    var required = GetStateSize();
    if (required > size)
    {
        throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");
    }

    unsafe
    {
        var target = (byte*)dest.ToPointer();
        return NativeApi.llama_copy_state_data(this, target);
    }
}
/// <summary>
/// Load the raw state of this context from the memory at `src`.
/// Forwards to the <see cref="IntPtr"/> overload.
/// </summary>
/// <param name="src">The pointer to read the state from</param>
/// <returns>Number of bytes read from the src pointer</returns>
public unsafe ulong SetState(byte* src) => SetState(new IntPtr(src));
/// <summary>
/// Load the raw state of this context from the memory at `src`.
/// </summary>
/// <param name="src">The pointer to read the state from</param>
/// <returns>Number of bytes read from the src pointer</returns>
public ulong SetState(IntPtr src)
{
    unsafe
    {
        // The native call takes a raw byte pointer, so unwrap the IntPtr first
        var source = (byte*)src.ToPointer();
        return NativeApi.llama_set_state_data(this, source);
    }
}
| #endregion | |||
| } | |||
| } | |||