diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs
index 9feb6202..85ec34f2 100644
--- a/LLama.Examples/Program.cs
+++ b/LLama.Examples/Program.cs
@@ -7,7 +7,11 @@ Console.WriteLine(" __ __ ____ _
Console.WriteLine("======================================================================================================");
-NativeLibraryConfig.Instance.WithCuda().WithLogs();
+NativeLibraryConfig
+ .Instance
+ .WithCuda()
+ .WithLogs()
+ .WithAvx(NativeLibraryConfig.AvxLevel.Avx512);
NativeApi.llama_empty_call();
Console.WriteLine();
diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs
index 395d447c..5dad1fc3 100644
--- a/LLama/Native/LLamaKvCacheView.cs
+++ b/LLama/Native/LLamaKvCacheView.cs
@@ -1,4 +1,5 @@
-using System.Runtime.InteropServices;
+using System;
+using System.Runtime.InteropServices;
namespace LLama.Native;
@@ -18,7 +19,6 @@ public struct LLamaKvCacheViewCell
///
/// An updateable view of the KV cache (llama_kv_cache_view)
///
-//todo: rewrite to safe handle?
[StructLayout(LayoutKind.Sequential)]
public unsafe struct LLamaKvCacheView
{
@@ -52,6 +52,84 @@ public unsafe struct LLamaKvCacheView
LLamaSeqId* cells_sequences;
}
+///
+/// A safe handle for a LLamaKvCacheView
+///
+public class LLamaKvCacheViewSafeHandle
+ : SafeLLamaHandleBase
+{
+ private readonly SafeLLamaContextHandle _ctx;
+ private LLamaKvCacheView _view;
+
+ ///
+ /// Initialize a LLamaKvCacheViewSafeHandle which will call `llama_kv_cache_view_free` when disposed
+ ///
+ ///
+ ///
+ public LLamaKvCacheViewSafeHandle(SafeLLamaContextHandle ctx, LLamaKvCacheView view)
+ : base(IntPtr.MaxValue, true)
+ {
+ _ctx = ctx;
+ _view = view;
+ }
+
+ ///
+ /// Allocate a new llama_kv_cache_view (via llama_kv_cache_view_init) wrapped in a safe handle
+ ///
+ ///
+ /// The maximum number of sequences visible in this view per cell
+ ///
+ public static LLamaKvCacheViewSafeHandle Allocate(SafeLLamaContextHandle ctx, int maxSequences)
+ {
+ var result = NativeApi.llama_kv_cache_view_init(ctx, maxSequences);
+ return new LLamaKvCacheViewSafeHandle(ctx, result);
+ }
+
+ ///
+ protected override bool ReleaseHandle()
+ {
+ NativeApi.llama_kv_cache_view_free(ref _view);
+ SetHandle(IntPtr.Zero);
+
+ return true;
+ }
+
+ ///
+ /// Update this view
+ ///
+ public void Update()
+ {
+ NativeApi.llama_kv_cache_view_update(_ctx, ref _view);
+ }
+
+ ///
+ /// Count the number of used cells in the KV cache
+ ///
+ ///
+ public int CountCells()
+ {
+ return NativeApi.llama_get_kv_cache_used_cells(_ctx);
+ }
+
+ ///
+ /// Count the number of tokens in the KV cache. If a token is assigned to multiple sequences it will be counted multiple times
+ ///
+ ///
+ public int CountTokens()
+ {
+ return NativeApi.llama_get_kv_cache_token_count(_ctx);
+ }
+
+ ///
+ /// Get the raw KV cache view
+ ///
+ ///
+ public ref LLamaKvCacheView GetView()
+ {
+ return ref _view;
+ }
+}
+
partial class NativeApi
{
///
@@ -66,9 +144,8 @@ partial class NativeApi
///
/// Free a KV cache view. (use only for debugging purposes)
///
- ///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe void llama_kv_cache_view_free(LLamaKvCacheView* view);
+ public static extern void llama_kv_cache_view_free(ref LLamaKvCacheView view);
///
/// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
@@ -76,7 +153,7 @@ partial class NativeApi
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, LLamaKvCacheView* view);
+ public static extern void llama_kv_cache_view_update(SafeLLamaContextHandle ctx, ref LLamaKvCacheView view);
///
/// Returns the number of tokens in the KV cache (slow, use only for debug)