Browse Source

Merge pull request #119 from martindevans/improved_cloning

Improved Cloning
tags/v0.5.1
Martin Evans GitHub 2 years ago
parent
commit
d367317c2c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 134 additions and 20 deletions
  1. +27
    -0
      LLama.Unittest/BasicTest.cs
  2. +12
    -19
      LLama/LLamaContext.cs
  3. +1
    -1
      LLama/LLamaEmbedder.cs
  4. +94
    -0
      LLama/Native/SafeLLamaContextHandle.cs

+ 27
- 0
LLama.Unittest/BasicTest.cs View File

@@ -26,5 +26,32 @@ namespace LLama.Unittest
Assert.Equal(2048, _model.ContextSize);
Assert.Equal(4096, _model.EmbeddingSize);
}

[Fact]
public void CloneContext()
{
    // Contexts are disposable; `using` releases them even when an assertion fails
    using var original = _model.CreateContext(_params);

    // Evaluate something (doesn't matter what, as long as it begins with token 1)
    original.Eval(new[] { 1, 42, 321 }, 0);

    // Clone current state
    using var clone = original.Clone();

    // Now evaluate something more
    var reply1a = original.Eval(new[] { 4, 5, 6 }, 3);
    var reply2a = original.Eval(new[] { 7, 8, 9 }, 6);

    // Assert that the context replied differently each time
    Assert.NotEqual(reply1a, reply2a);

    // Give the same prompts to the cloned state
    var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3);
    var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6);

    // Assert that the cloned context replied in the same way as originally
    Assert.Equal(reply1a, reply1b);
    Assert.Equal(reply2a, reply2b);
}
}
}

+ 12
- 19
LLama/LLamaContext.cs View File

@@ -19,7 +19,7 @@ namespace LLama
/// <summary>
/// A llama_context, which holds all the context required to interact with a model
/// </summary>
public class LLamaContext
public sealed class LLamaContext
: IDisposable
{
private readonly ILLamaLogger? _logger;
@@ -111,15 +111,8 @@ namespace LLama
public LLamaContext Clone()
{
using var pin = Params.ToLlamaContextParams(out var lparams);

// Create a blank new context for the model
var ctx = new LLamaContext(SafeLLamaContextHandle.Create(NativeHandle.ModelHandle, lparams), Params);

// Copy across the state
using var state = GetState();
ctx.LoadState(state);

return ctx;
var clone = _ctx.Clone(lparams);
return new LLamaContext(clone, Params);
}

/// <summary>
@@ -197,7 +190,7 @@ namespace LLama
/// <returns></returns>
public State GetState()
{
var stateSize = NativeApi.llama_get_state_size(_ctx);
var stateSize = _ctx.GetStateSize();

unsafe
{
@@ -206,15 +199,17 @@ namespace LLama
try
{
// Copy the state data into "big memory", discover the actual size required
var actualSize = NativeApi.llama_copy_state_data(_ctx, (byte*)bigMemory);
var actualSize = _ctx.GetState(bigMemory, stateSize);

// if big memory is nearly completely full (within 1MB) early exit and skip the extra copying
if (actualSize >= stateSize - 1_000_000)
return new State(bigMemory);

// Allocate a smaller buffer
// Allocate a smaller buffer which is exactly the right size
smallMemory = Marshal.AllocHGlobal((nint)actualSize);

// Copy into the smaller buffer and free the large one to save excess memory usage
Buffer.MemoryCopy(bigMemory.ToPointer(), smallMemory.ToPointer(), actualSize, actualSize);
Marshal.FreeHGlobal(bigMemory);
bigMemory = IntPtr.Zero;

return new State(smallMemory);
}
@@ -274,7 +269,7 @@ namespace LLama
{
unsafe
{
NativeApi.llama_set_state_data(_ctx, (byte*)state.DangerousGetHandle().ToPointer());
_ctx.SetState((byte*)state.DangerousGetHandle().ToPointer());
}
}

@@ -498,10 +493,8 @@ namespace LLama
}

/// <inheritdoc />
public virtual void Dispose()
public void Dispose()
{
GC.SuppressFinalize(this);

_ctx.Dispose();
}



+ 1
- 1
LLama/LLamaEmbedder.cs View File

@@ -8,7 +8,7 @@ namespace LLama
/// <summary>
/// The embedder for LLama, which supports getting embeddings from text.
/// </summary>
public class LLamaEmbedder
public sealed class LLamaEmbedder
: IDisposable
{
private readonly LLamaContext _ctx;


+ 94
- 0
LLama/Native/SafeLLamaContextHandle.cs View File

@@ -1,5 +1,6 @@
using System;
using System.Buffers;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;

@@ -89,6 +90,35 @@ namespace LLama.Native

return new(ctx_ptr, model);
}

/// <summary>
/// Create a new llama context with a clone of the current llama context state
/// </summary>
/// <param name="lparams">Parameters used when creating the new context</param>
/// <returns>A new context handle, created for the same model, with this context's state copied into it</returns>
public SafeLLamaContextHandle Clone(LLamaContextParams lparams)
{
    // Allocate space to read the state of the current context
    var stateSize = GetStateSize();
    var stateMemory = Marshal.AllocHGlobal((nint)stateSize);
    try
    {
        // Copy state from this context into memory
        GetState(stateMemory, stateSize);

        // Create a new context
        var newCtx = Create(ModelHandle, lparams);
        try
        {
            // Copy state into new context
            newCtx.SetState(stateMemory);
        }
        catch
        {
            // The caller never receives the handle if loading the state fails,
            // so release it here rather than leaving it undisposed
            newCtx.Dispose();
            throw;
        }

        return newCtx;
    }
    finally
    {
        // The temporary state buffer is only needed for the duration of the copy
        Marshal.FreeHGlobal(stateMemory);
    }
}
#endregion

/// <summary>
@@ -188,5 +218,69 @@ namespace LLama.Native
}
}
}

#region state
/// <summary>
/// Query how many bytes this context's state occupies when saved
/// </summary>
/// <returns>The size in bytes, as reported by <c>llama_get_state_size</c></returns>
public ulong GetStateSize() => NativeApi.llama_get_state_size(this);

/// <summary>
/// Write the raw state of this context, encoded as bytes, into the memory at `dest`.
/// Forwards to the <see cref="IntPtr"/> overload.
/// </summary>
/// <param name="dest">Destination to write to</param>
/// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
/// <returns>The number of bytes written to dest</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
public unsafe ulong GetState(byte* dest, ulong size) => GetState((IntPtr)dest, size);

/// <summary>
/// Get the raw state of this context, encoded as bytes. Data is written into the `dest` pointer.
/// </summary>
/// <param name="dest">Destination to write to, must not be null</param>
/// <param name="size">Number of bytes available to write to in dest (check required size with `GetStateSize()`)</param>
/// <returns>The number of bytes written to dest</returns>
/// <exception cref="ArgumentNullException">Thrown if dest is a null pointer</exception>
/// <exception cref="ArgumentOutOfRangeException">Thrown if dest is too small</exception>
public ulong GetState(IntPtr dest, ulong size)
{
    // Reject a null destination here instead of handing it to native code
    if (dest == IntPtr.Zero)
        throw new ArgumentNullException(nameof(dest), "Destination pointer must not be null");

    // Refuse to write unless the caller supplied at least as much space as the state requires
    var required = GetStateSize();
    if (size < required)
        throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");

    unsafe
    {
        return NativeApi.llama_copy_state_data(this, (byte*)dest.ToPointer());
    }
}

/// <summary>
/// Load raw state into this context from the memory at `src`.
/// Forwards to the <see cref="IntPtr"/> overload.
/// </summary>
/// <param name="src">The pointer to read the state from</param>
/// <returns>Number of bytes read from the src pointer</returns>
public unsafe ulong SetState(byte* src) => SetState((IntPtr)src);

/// <summary>
/// Set the raw state of this context
/// </summary>
/// <param name="src">The pointer to read the state from, must not be null</param>
/// <returns>Number of bytes read from the src pointer</returns>
/// <exception cref="ArgumentNullException">Thrown if src is a null pointer</exception>
public ulong SetState(IntPtr src)
{
    // Reject a null source here instead of handing it to native code
    if (src == IntPtr.Zero)
        throw new ArgumentNullException(nameof(src), "Source pointer must not be null");

    unsafe
    {
        return NativeApi.llama_set_state_data(this, (byte*)src.ToPointer());
    }
}
#endregion
}
}

Loading…
Cancel
Save