| @@ -329,7 +329,7 @@ namespace LLama.Native | |||||
| #endif | #endif | ||||
| } | } | ||||
| private const string libraryName = "libllama"; | |||||
| internal const string libraryName = "libllama"; | |||||
| private const string cudaVersionFile = "version.json"; | private const string cudaVersionFile = "version.json"; | ||||
| private const string loggingPrefix = "[LLamaSharp Native]"; | private const string loggingPrefix = "[LLamaSharp Native]"; | ||||
| private static bool enableLogging = false; | private static bool enableLogging = false; | ||||
| @@ -25,8 +25,10 @@ namespace LLama.Native | |||||
| /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. | /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. | ||||
| /// </summary> | /// </summary> | ||||
| /// <returns></returns> | /// <returns></returns> | ||||
| [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern bool llama_empty_call(); | |||||
| public static void llama_empty_call() | |||||
| { | |||||
| llama_mmap_supported(); | |||||
| } | |||||
| /// <summary> | /// <summary> | ||||
| /// Get the maximum number of devices supported by llama.cpp | /// Get the maximum number of devices supported by llama.cpp | ||||
| @@ -70,25 +72,6 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| public static extern bool llama_mlock_supported(); | public static extern bool llama_mlock_supported(); | ||||
| /// <summary> | |||||
| /// Load all of the weights of a model into memory. | |||||
| /// </summary> | |||||
| /// <param name="path_model"></param> | |||||
| /// <param name="params"></param> | |||||
| /// <returns>The loaded model, or null on failure.</returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params); | |||||
| /// <summary> | |||||
| /// Create a new llama_context with the given model. | |||||
| /// Return value should always be wrapped in SafeLLamaContextHandle! | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| /// <param name="params"></param> | |||||
| /// <returns></returns> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params); | |||||
| /// <summary> | /// <summary> | ||||
| /// Initialize the llama + ggml backend | /// Initialize the llama + ggml backend | ||||
| /// Call once at the start of the program | /// Call once at the start of the program | ||||
| @@ -96,20 +79,6 @@ namespace LLama.Native | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | ||||
| private static extern void llama_backend_init(bool numa); | private static extern void llama_backend_init(bool numa); | ||||
| /// <summary> | |||||
| /// Frees all allocated memory in the given llama_context | |||||
| /// </summary> | |||||
| /// <param name="ctx"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_free(IntPtr ctx); | |||||
| /// <summary> | |||||
| /// Frees all allocated memory associated with a model | |||||
| /// </summary> | |||||
| /// <param name="model"></param> | |||||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| public static extern void llama_free_model(IntPtr model); | |||||
| /// <summary> | /// <summary> | ||||
| /// Apply a LoRA adapter to a loaded model | /// Apply a LoRA adapter to a loaded model | ||||
| /// path_base_model is the path to a higher quality model to use as a base for | /// path_base_model is the path to a higher quality model to use as a base for | ||||
| @@ -1,5 +1,6 @@ | |||||
| using System; | using System; | ||||
| using System.Buffers; | using System.Buffers; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| @@ -8,6 +9,7 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// A safe wrapper around a llama_context | /// A safe wrapper around a llama_context | ||||
| /// </summary> | /// </summary> | ||||
| // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API) | |||||
| public sealed class SafeLLamaContextHandle | public sealed class SafeLLamaContextHandle | ||||
| : SafeLLamaHandleBase | : SafeLLamaHandleBase | ||||
| { | { | ||||
| @@ -36,26 +38,10 @@ namespace LLama.Native | |||||
| #endregion | #endregion | ||||
| #region construction/destruction | #region construction/destruction | ||||
| /// <summary> | |||||
| /// Create a new SafeLLamaContextHandle | |||||
| /// </summary> | |||||
| /// <param name="handle">pointer to an allocated llama_context</param> | |||||
| /// <param name="model">the model which this context was created from</param> | |||||
| public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model) | |||||
| : base(handle) | |||||
| { | |||||
| // Increment the model reference count while this context exists | |||||
| _model = model; | |||||
| var success = false; | |||||
| _model.DangerousAddRef(ref success); | |||||
| if (!success) | |||||
| throw new RuntimeError("Failed to increment model refcount"); | |||||
| } | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override bool ReleaseHandle() | protected override bool ReleaseHandle() | ||||
| { | { | ||||
| NativeApi.llama_free(DangerousGetHandle()); | |||||
| llama_free(handle); | |||||
| SetHandle(IntPtr.Zero); | SetHandle(IntPtr.Zero); | ||||
| // Decrement refcount on model | // Decrement refcount on model | ||||
| @@ -84,12 +70,42 @@ namespace LLama.Native | |||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams) | public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams) | ||||
| { | { | ||||
| var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams); | |||||
| if (ctx_ptr == IntPtr.Zero) | |||||
| var ctx = llama_new_context_with_model(model, lparams); | |||||
| // A P/Invoke that returns a SafeHandle-derived type always yields a non-null wrapper object, | |||||
| // even when the native function returned a null pointer. Failure therefore has to be detected | |||||
| // through IsInvalid; the null test is kept only as a defensive guard. | |||||
| if (ctx == null || ctx.IsInvalid) | |||||
| throw new RuntimeError("Failed to create context from model"); | throw new RuntimeError("Failed to create context from model"); | ||||
| return new(ctx_ptr, model); | |||||
| // Increment the model reference count while this context exists. | |||||
| // DangerousAddRef throws if it fails, so there is no need to check "success" | |||||
| ctx._model = model; | |||||
| var success = false; | |||||
| ctx._model.DangerousAddRef(ref success); | |||||
| return ctx; | |||||
| } | |||||
| #endregion | |||||
| #region Native API | |||||
| static SafeLLamaContextHandle() | |||||
| { | |||||
| // Force `NativeApi` to be loaded first: the two DllImport methods declared below only reference the NativeApi.libraryName constant and never call into NativeApi themselves | |||||
| NativeApi.llama_empty_call(); | |||||
| } | } | ||||
| /// <summary> | |||||
| /// Create a new llama_context with the given model. **This should never be called directly! Always use SafeLLamaContextHandle.Create**! | |||||
| /// </summary> | |||||
| /// <param name="model">Handle of the loaded model that the new context is created from</param> | |||||
| /// <param name="params">Context parameters passed through to the native function</param> | |||||
| /// <returns>A handle wrapping the native llama_context pointer. NOTE(review): the interop marshaller is expected to return a non-null wrapper even if the native pointer is null, so validity should be checked on the handle itself - confirm</returns> | |||||
| [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| private static extern SafeLLamaContextHandle llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params); | |||||
| /// <summary> | |||||
| /// Frees all allocated memory in the given llama_context. Called from ReleaseHandle with the raw handle value. | |||||
| /// </summary> | |||||
| /// <param name="ctx">Raw pointer to the native llama_context to free</param> | |||||
| [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| private static extern void llama_free(IntPtr ctx); | |||||
| #endregion | #endregion | ||||
| /// <summary> | /// <summary> | ||||
| @@ -31,6 +31,6 @@ namespace LLama.Native | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| public override string ToString() | public override string ToString() | ||||
| => $"0x{handle.ToString("x16")}"; | |||||
| // Call ToString("x16") explicitly: an interpolation format specifier ({handle:x16}) is only honoured when | |||||
| // the value implements IFormattable, which IntPtr does not on older target frameworks (it would print decimal). | |||||
| => "0x" + handle.ToString("x16"); | |||||
| } | } | ||||
| } | } | ||||
| @@ -1,6 +1,7 @@ | |||||
| using System; | using System; | ||||
| using System.Collections.Generic; | using System.Collections.Generic; | ||||
| using System.Diagnostics; | using System.Diagnostics; | ||||
| using System.Runtime.InteropServices; | |||||
| using System.Text; | using System.Text; | ||||
| using LLama.Exceptions; | using LLama.Exceptions; | ||||
| using LLama.Extensions; | using LLama.Extensions; | ||||
| @@ -10,6 +11,7 @@ namespace LLama.Native | |||||
| /// <summary> | /// <summary> | ||||
| /// A reference to a set of llama model weights | /// A reference to a set of llama model weights | ||||
| /// </summary> | /// </summary> | ||||
| // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API) | |||||
| public sealed class SafeLlamaModelHandle | public sealed class SafeLlamaModelHandle | ||||
| : SafeLLamaHandleBase | : SafeLLamaHandleBase | ||||
| { | { | ||||
| @@ -47,8 +49,7 @@ namespace LLama.Native | |||||
| /// <inheritdoc /> | /// <inheritdoc /> | ||||
| protected override bool ReleaseHandle() | protected override bool ReleaseHandle() | ||||
| { | { | ||||
| NativeApi.llama_free_model(DangerousGetHandle()); | |||||
| SetHandle(IntPtr.Zero); | |||||
| llama_free_model(handle); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -61,13 +62,37 @@ namespace LLama.Native | |||||
| /// <exception cref="RuntimeError"></exception> | /// <exception cref="RuntimeError"></exception> | ||||
| public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams) | public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams) | ||||
| { | { | ||||
| var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams); | |||||
| if (model_ptr == null) | |||||
| var model = llama_load_model_from_file(modelPath, lparams); | |||||
| // The marshaller always constructs a SafeLlamaModelHandle for the return value, so a failed load | |||||
| // shows up as a non-null handle wrapping a null pointer. Check IsInvalid to detect it; the null | |||||
| // test is kept only as a defensive guard. | |||||
| if (model == null || model.IsInvalid) | |||||
| throw new RuntimeError($"Failed to load model {modelPath}."); | throw new RuntimeError($"Failed to load model {modelPath}."); | ||||
| return model_ptr; | |||||
| return model; | |||||
| } | } | ||||
| #region native API | |||||
| static SafeLlamaModelHandle() | |||||
| { | |||||
| // Force `NativeApi` to be loaded first: the two DllImport methods declared below only reference the NativeApi.libraryName constant and never call into NativeApi themselves | |||||
| NativeApi.llama_empty_call(); | |||||
| } | |||||
| /// <summary> | |||||
| /// Load all of the weights of a model into memory. Called by LoadFromFile. | |||||
| /// </summary> | |||||
| /// <param name="path_model">Path of the model file to load</param> | |||||
| /// <param name="params">Model parameters passed through to the native function</param> | |||||
| /// <returns>A handle wrapping the loaded model. NOTE(review): on failure the interop marshaller is expected to return a non-null handle wrapping a null pointer rather than a null reference, so callers should test the handle's validity - confirm</returns> | |||||
| [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| private static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params); | |||||
| /// <summary> | |||||
| /// Frees all allocated memory associated with a model. Called from ReleaseHandle with the raw handle value. | |||||
| /// </summary> | |||||
| /// <param name="model">Raw pointer to the native model to free</param> | |||||
| [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)] | |||||
| private static extern void llama_free_model(IntPtr model); | |||||
| #endregion | |||||
| #region LoRA | #region LoRA | ||||
| /// <summary> | /// <summary> | ||||