diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs
index 5ae02c1a..9153c1f2 100644
--- a/LLama/Native/NativeApi.Load.cs
+++ b/LLama/Native/NativeApi.Load.cs
@@ -329,7 +329,7 @@ namespace LLama.Native
#endif
}
- private const string libraryName = "libllama";
+ internal const string libraryName = "libllama";
private const string cudaVersionFile = "version.json";
private const string loggingPrefix = "[LLamaSharp Native]";
private static bool enableLogging = false;
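Widening libraryName from private to internal is what lets the safe-handle classes later in this diff declare their own [DllImport] bindings against the same native library. A minimal sketch of that pattern, using a hypothetical binding class (llama_mlock_supported is a real export, used here only as an illustration):

    using System.Runtime.InteropServices;

    namespace LLama.Native
    {
        // Hypothetical example class, not part of this diff.
        internal static class ExampleBindings
        {
            static ExampleBindings()
            {
                // Run NativeApi's library-resolution logic before any
                // DllImport declared on this type is used.
                NativeApi.llama_empty_call();
            }

            // Binds against the shared library name now that it is internal.
            [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
            internal static extern bool llama_mlock_supported();
        }
    }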
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 2a34820d..22cc483e 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -25,8 +25,10 @@ namespace LLama.Native
/// <summary>
/// A method that does nothing. This is a native method; calling it will force the llama native dependencies to be loaded.
/// </summary>
- [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
- public static extern bool llama_empty_call();
+ public static void llama_empty_call()
+ {
+ llama_mmap_supported();
+ }
/// <summary>
/// Get the maximum number of devices supported by llama.cpp
@@ -70,25 +72,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_mlock_supported();
- /// <summary>
- /// Load all of the weights of a model into memory.
- /// </summary>
- /// <param name="path_model"></param>
- /// <param name="params"></param>
- /// <returns>The loaded model, or null on failure.</returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);
-
- /// <summary>
- /// Create a new llama_context with the given model.
- /// Return value should always be wrapped in SafeLLamaContextHandle!
- /// </summary>
- /// <param name="model"></param>
- /// <param name="params"></param>
- /// <returns></returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);
-
/// <summary>
/// Initialize the llama + ggml backend
/// Call once at the start of the program
@@ -96,20 +79,6 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_backend_init(bool numa);
- /// <summary>
- /// Frees all allocated memory in the given llama_context
- /// </summary>
- /// <param name="ctx"></param>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_free(IntPtr ctx);
-
- /// <summary>
- /// Frees all allocated memory associated with a model
- /// </summary>
- /// <param name="model"></param>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_free_model(IntPtr model);
-
/// <summary>
/// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for
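The bindings deleted above reappear below as private members of the corresponding SafeHandle classes. The key mechanism is that P/Invoke can marshal a returned native pointer directly into a SafeHandle subclass, so the raw IntPtr never escapes. A self-contained sketch of that marshalling contract, with stand-in names rather than LLamaSharp code:

    using System;
    using System.Runtime.InteropServices;

    // Stand-in SafeHandle showing what the marshaller requires of a
    // P/Invoke return type: a parameterless constructor plus
    // IsInvalid/ReleaseHandle implementations.
    internal sealed class ExampleNativeHandle : SafeHandle
    {
        // Called by the P/Invoke marshaller before it stores the
        // native pointer via SetHandle.
        public ExampleNativeHandle()
            : base(IntPtr.Zero, ownsHandle: true)
        {
        }

        // A null native pointer marshals to an *invalid* handle, not a
        // null reference, which is why failure checks belong here.
        public override bool IsInvalid => handle == IntPtr.Zero;

        protected override bool ReleaseHandle()
        {
            // example_free(handle); // hypothetical native free function
            return true;
        }
    }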
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 98b51078..b6c8eca6 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,5 +1,6 @@
using System;
using System.Buffers;
+using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@@ -8,6 +9,7 @@ namespace LLama.Native
/// <summary>
/// A safe wrapper around a llama_context
/// </summary>
+ // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API)
public sealed class SafeLLamaContextHandle
: SafeLLamaHandleBase
{
@@ -36,26 +38,10 @@ namespace LLama.Native
#endregion
#region construction/destruction
- /// <summary>
- /// Create a new SafeLLamaContextHandle
- /// </summary>
- /// <param name="handle">pointer to an allocated llama_context</param>
- /// <param name="model">the model which this context was created from</param>
- public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
- : base(handle)
- {
- // Increment the model reference count while this context exists
- _model = model;
- var success = false;
- _model.DangerousAddRef(ref success);
- if (!success)
- throw new RuntimeError("Failed to increment model refcount");
- }
-
/// <inheritdoc />
protected override bool ReleaseHandle()
{
- NativeApi.llama_free(DangerousGetHandle());
+ llama_free(handle);
SetHandle(IntPtr.Zero);
// Decrement refcount on model
@@ -84,12 +70,42 @@ namespace LLama.Native
///
public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
{
- var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
- if (ctx_ptr == IntPtr.Zero)
+ var ctx = llama_new_context_with_model(model, lparams);
+ if (ctx == null || ctx.IsInvalid)
throw new RuntimeError("Failed to create context from model");
- return new(ctx_ptr, model);
+ // Increment the model reference count while this context exists.
+ // DangerousAddRef throws if it fails, so there is no need to check "success"
+ ctx._model = model;
+ var success = false;
+ ctx._model.DangerousAddRef(ref success);
+
+ return ctx;
+ }
+ #endregion
+
+ #region Native API
+ static SafeLLamaContextHandle()
+ {
+ // This ensures that `NativeApi` has been loaded before calling the two native methods below
+ NativeApi.llama_empty_call();
}
+
+ /// <summary>
+ /// Create a new llama_context with the given model. **This should never be called directly! Always use SafeLLamaContextHandle.Create!**
+ /// </summary>
+ /// <param name="model"></param>
+ /// <param name="params"></param>
+ /// <returns></returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern SafeLLamaContextHandle llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);
+
+ /// <summary>
+ /// Frees all allocated memory in the given llama_context
+ /// </summary>
+ /// <param name="ctx"></param>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern void llama_free(IntPtr ctx);
#endregion
///
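The static constructor above leans on a C# guarantee: a type with an explicit static constructor runs that initializer before the first use of any of its static members, including the extern methods. A small stand-alone sketch of that ordering, with stand-in classes rather than LLamaSharp types:

    using System;

    internal static class ExampleLoader
    {
        // Stand-in for NativeApi's library-loading side effects.
        public static void EnsureLoaded() => Console.WriteLine("library resolved");
    }

    internal static class ExampleBinding
    {
        static ExampleBinding()
        {
            // Guaranteed to run before FirstCall below ever executes,
            // mirroring the NativeApi.llama_empty_call() call in the diff.
            ExampleLoader.EnsureLoaded();
        }

        public static void FirstCall() => Console.WriteLine("native call");
    }

    // Usage: ExampleBinding.FirstCall() prints "library resolved" first.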
diff --git a/LLama/Native/SafeLLamaHandleBase.cs b/LLama/Native/SafeLLamaHandleBase.cs
index 6371b327..7171c803 100644
--- a/LLama/Native/SafeLLamaHandleBase.cs
+++ b/LLama/Native/SafeLLamaHandleBase.cs
@@ -31,6 +31,6 @@ namespace LLama.Native
/// <inheritdoc />
public override string ToString()
- => $"0x{handle.ToString("x16")}";
+ => $"0x{handle:x16}";
}
}
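The two ToString forms are equivalent on runtimes where IntPtr implements IFormattable (modern .NET): the interpolation specifier :x16 is routed to the same "x16" format string as the explicit call. For example, in a 64-bit process:

    using System;

    IntPtr handle = new IntPtr(0xBEEF);
    Console.WriteLine($"0x{handle:x16}");             // 0x000000000000beef
    Console.WriteLine($"0x{handle.ToString("x16")}"); // 0x000000000000beef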
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 291cfbc2..1e9fbca3 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
+using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
using LLama.Extensions;
@@ -10,6 +11,7 @@ namespace LLama.Native
/// <summary>
/// A reference to a set of llama model weights
/// </summary>
+ // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API)
public sealed class SafeLlamaModelHandle
: SafeLLamaHandleBase
{
@@ -47,8 +49,7 @@ namespace LLama.Native
/// <inheritdoc />
protected override bool ReleaseHandle()
{
- NativeApi.llama_free_model(DangerousGetHandle());
- SetHandle(IntPtr.Zero);
+ llama_free_model(handle);
return true;
}
@@ -61,13 +62,37 @@ namespace LLama.Native
///
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
{
- var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
- if (model_ptr == null)
+ var model = llama_load_model_from_file(modelPath, lparams);
+ if (model == null || model.IsInvalid)
throw new RuntimeError($"Failed to load model {modelPath}.");
- return model_ptr;
+ return model;
}
+ #region native API
+ static SafeLlamaModelHandle()
+ {
+ // This ensures that `NativeApi` has been loaded before calling the two native methods below
+ NativeApi.llama_empty_call();
+ }
+
+ /// <summary>
+ /// Load all of the weights of a model into memory.
+ /// </summary>
+ /// <param name="path_model"></param>
+ /// <param name="params"></param>
+ /// <returns>The loaded model, or null on failure.</returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);
+
+ /// <summary>
+ /// Frees all allocated memory associated with a model
+ /// </summary>
+ /// <param name="model"></param>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern void llama_free_model(IntPtr model);
+ #endregion
+
#region LoRA
///
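Taken together, the intended call flow is: load the weights through SafeLlamaModelHandle.LoadFromFile, then create contexts through SafeLLamaContextHandle.Create, which takes a reference on the model so the weights outlive every context. A hedged usage sketch (the model path is a placeholder, and the parameter structs are zero-initialized only to keep the sketch self-contained):

    // Hypothetical usage of the API surface introduced in this diff.
    var modelParams = default(LLamaModelParams);
    var ctxParams = default(LLamaContextParams);

    using var model = SafeLlamaModelHandle.LoadFromFile("path/to/model.bin", modelParams);
    using var ctx = SafeLLamaContextHandle.Create(model, ctxParams);
    // Create called DangerousAddRef on the model, so even if the model
    // handle were disposed first, the weights stay alive until the
    // context's ReleaseHandle runs and decrements the refcount.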