From 1e69e265b62bec36fc87363bb404a30db00a85e5 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sat, 6 Jan 2024 23:39:25 +0000 Subject: [PATCH] Moved some native methods to do with creating/destroying resources into their respective handles. There is **no** safe way to call most of these methods, everything must be done through handles. --- LLama/Native/NativeApi.Load.cs | 2 +- LLama/Native/NativeApi.cs | 39 ++---------------- LLama/Native/SafeLLamaContextHandle.cs | 56 +++++++++++++++++--------- LLama/Native/SafeLLamaHandleBase.cs | 2 +- LLama/Native/SafeLlamaModelHandle.cs | 35 +++++++++++++--- 5 files changed, 72 insertions(+), 62 deletions(-) diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs index 5ae02c1a..9153c1f2 100644 --- a/LLama/Native/NativeApi.Load.cs +++ b/LLama/Native/NativeApi.Load.cs @@ -329,7 +329,7 @@ namespace LLama.Native #endif } - private const string libraryName = "libllama"; + internal const string libraryName = "libllama"; private const string cudaVersionFile = "version.json"; private const string loggingPrefix = "[LLamaSharp Native]"; private static bool enableLogging = false; diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 2a34820d..22cc483e 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -25,8 +25,10 @@ namespace LLama.Native /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. 
/// /// - [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] - public static extern bool llama_empty_call(); + public static void llama_empty_call() + { + llama_mmap_supported(); + } /// /// Get the maximum number of devices supported by llama.cpp @@ -70,25 +72,6 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern bool llama_mlock_supported(); - /// - /// Load all of the weights of a model into memory. - /// - /// - /// - /// The loaded model, or null on failure. - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params); - - /// - /// Create a new llama_context with the given model. - /// Return value should always be wrapped in SafeLLamaContextHandle! - /// - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params); - /// /// Initialize the llama + ggml backend /// Call once at the start of the program @@ -96,20 +79,6 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] private static extern void llama_backend_init(bool numa); - /// - /// Frees all allocated memory in the given llama_context - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_free(IntPtr ctx); - - /// - /// Frees all allocated memory associated with a model - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_free_model(IntPtr model); - /// /// Apply a LoRA adapter to a loaded model /// path_base_model is the path to a higher quality model to use as a base for diff --git a/LLama/Native/SafeLLamaContextHandle.cs 
b/LLama/Native/SafeLLamaContextHandle.cs index 98b51078..b6c8eca6 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -1,5 +1,6 @@ using System; using System.Buffers; +using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; @@ -8,6 +9,7 @@ namespace LLama.Native /// /// A safe wrapper around a llama_context /// + // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API) public sealed class SafeLLamaContextHandle : SafeLLamaHandleBase { @@ -36,26 +38,10 @@ namespace LLama.Native #endregion #region construction/destruction - /// - /// Create a new SafeLLamaContextHandle - /// - /// pointer to an allocated llama_context - /// the model which this context was created from - public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model) - : base(handle) - { - // Increment the model reference count while this context exists - _model = model; - var success = false; - _model.DangerousAddRef(ref success); - if (!success) - throw new RuntimeError("Failed to increment model refcount"); - } - /// protected override bool ReleaseHandle() { - NativeApi.llama_free(DangerousGetHandle()); + llama_free(handle); SetHandle(IntPtr.Zero); // Decrement refcount on model @@ -84,12 +70,42 @@ namespace LLama.Native /// public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams) { - var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams); - if (ctx_ptr == IntPtr.Zero) + var ctx = llama_new_context_with_model(model, lparams); + if (ctx == null) throw new RuntimeError("Failed to create context from model"); - return new(ctx_ptr, model); + // Increment the model reference count while this context exists. 
+ // DangerousAddRef throws if it fails, so there is no need to check "success" + ctx._model = model; + var success = false; + ctx._model.DangerousAddRef(ref success); + + return ctx; + } + #endregion + + #region Native API + static SafeLLamaContextHandle() + { + // This ensures that `NativeApi` has been loaded before calling the two native methods below + NativeApi.llama_empty_call(); } + + /// + /// Create a new llama_context with the given model. **This should never be called directly! Always use SafeLLamaContextHandle.Create**! + /// + /// + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern SafeLLamaContextHandle llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params); + + /// + /// Frees all allocated memory in the given llama_context + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern void llama_free(IntPtr ctx); #endregion /// diff --git a/LLama/Native/SafeLLamaHandleBase.cs b/LLama/Native/SafeLLamaHandleBase.cs index 6371b327..7171c803 100644 --- a/LLama/Native/SafeLLamaHandleBase.cs +++ b/LLama/Native/SafeLLamaHandleBase.cs @@ -31,6 +31,6 @@ namespace LLama.Native /// public override string ToString() - => $"0x{handle.ToString("x16")}"; + => $"0x{handle:x16}"; } } diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 291cfbc2..1e9fbca3 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; using LLama.Extensions; @@ -10,6 +11,7 @@ namespace LLama.Native /// /// A reference to a set of llama model weights /// + // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API) public sealed class SafeLlamaModelHandle : 
SafeLLamaHandleBase { @@ -47,8 +49,7 @@ namespace LLama.Native /// protected override bool ReleaseHandle() { - NativeApi.llama_free_model(DangerousGetHandle()); - SetHandle(IntPtr.Zero); + llama_free_model(handle); return true; } @@ -61,13 +62,37 @@ namespace LLama.Native /// public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams) { - var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams); - if (model_ptr == null) + var model = llama_load_model_from_file(modelPath, lparams); + if (model == null) throw new RuntimeError($"Failed to load model {modelPath}."); - return model_ptr; + return model; } + #region native API + static SafeLlamaModelHandle() + { + // This ensures that `NativeApi` has been loaded before calling the two native methods below + NativeApi.llama_empty_call(); + } + + /// + /// Load all of the weights of a model into memory. + /// + /// + /// + /// The loaded model, or null on failure. + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params); + + /// + /// Frees all allocated memory associated with a model + /// + /// + [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] + private static extern void llama_free_model(IntPtr model); + #endregion + #region LoRA ///