
Moved some native methods for creating/destroying resources into their respective handles. There is **no** safe way to call most of these methods; everything must be done through the handles.

tags/v0.10.0
Martin Evans 2 years ago
parent
commit 1e69e265b6
5 changed files with 72 additions and 62 deletions
  1. LLama/Native/NativeApi.Load.cs (+1 -1)
  2. LLama/Native/NativeApi.cs (+4 -35)
  3. LLama/Native/SafeLLamaContextHandle.cs (+36 -20)
  4. LLama/Native/SafeLLamaHandleBase.cs (+1 -1)
  5. LLama/Native/SafeLlamaModelHandle.cs (+30 -5)
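
Since every entry point is now private to its handle, the public surface is reduced to the two factory methods shown in the diffs below. A minimal sketch of the resulting call pattern, assuming LLamaModelParams and LLamaContextParams are configured elsewhere (their setup is not part of this commit):

    using LLama.Native;

    class HandleUsageSketch
    {
        static void Run(string modelPath, LLamaModelParams mparams, LLamaContextParams cparams)
        {
            // The only route to llama_load_model_from_file now: via the handle
            using var model = SafeLlamaModelHandle.LoadFromFile(modelPath, mparams);

            // Likewise, llama_new_context_with_model is private to this handle
            using var ctx = SafeLLamaContextHandle.Create(model, cparams);

            // ... use ctx; disposal frees the context, then releases the model ref ...
        }
    }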

LLama/Native/NativeApi.Load.cs (+1 -1)

@@ -329,7 +329,7 @@ namespace LLama.Native
 #endif
         }

-        private const string libraryName = "libllama";
+        internal const string libraryName = "libllama";
         private const string cudaVersionFile = "version.json";
         private const string loggingPrefix = "[LLamaSharp Native]";
         private static bool enableLogging = false;
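
Widening libraryName from private to internal is what lets the handle classes declare their own P/Invokes below; attribute arguments must be compile-time constants, so a shared const (rather than a property or readonly field) is required. The later hunks in this commit use it like this:

    [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
    private static extern void llama_free(IntPtr ctx);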


LLama/Native/NativeApi.cs (+4 -35)

@@ -25,8 +25,10 @@ namespace LLama.Native
         /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded.
         /// </summary>
         /// <returns></returns>
-        [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
-        public static extern bool llama_empty_call();
+        public static void llama_empty_call()
+        {
+            llama_mmap_supported();
+        }

         /// <summary>
         /// Get the maximum number of devices supported by llama.cpp
@@ -70,25 +72,6 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         public static extern bool llama_mlock_supported();

-        /// <summary>
-        /// Load all of the weights of a model into memory.
-        /// </summary>
-        /// <param name="path_model"></param>
-        /// <param name="params"></param>
-        /// <returns>The loaded model, or null on failure.</returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);
-
-        /// <summary>
-        /// Create a new llama_context with the given model.
-        /// Return value should always be wrapped in SafeLLamaContextHandle!
-        /// </summary>
-        /// <param name="model"></param>
-        /// <param name="params"></param>
-        /// <returns></returns>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);
-
         /// <summary>
         /// Initialize the llama + ggml backend
         /// Call once at the start of the program
@@ -96,20 +79,6 @@ namespace LLama.Native
         [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
         private static extern void llama_backend_init(bool numa);

-        /// <summary>
-        /// Frees all allocated memory in the given llama_context
-        /// </summary>
-        /// <param name="ctx"></param>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern void llama_free(IntPtr ctx);
-
-        /// <summary>
-        /// Frees all allocated memory associated with a model
-        /// </summary>
-        /// <param name="model"></param>
-        [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-        public static extern void llama_free_model(IntPtr model);
-
         /// <summary>
         /// Apply a LoRA adapter to a loaded model
         /// path_base_model is the path to a higher quality model to use as a base for
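
Replacing the extern llama_empty_call with a managed wrapper keeps a cheap way to force the native library to load. P/Invokes bind lazily on first call, so classes that declare their own externs call it from a static constructor first, as the handle diffs below do. A sketch of that pattern with a hypothetical binding class (llama_max_devices is a real llama.cpp export; everything else here is illustrative):

    using System.Runtime.InteropServices;
    using LLama.Native;

    internal static class ExampleBindings
    {
        // Runs before any member of this class is touched, guaranteeing the
        // custom library resolution in NativeApi.Load.cs has already happened
        static ExampleBindings()
        {
            NativeApi.llama_empty_call();
        }

        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
        public static extern int llama_max_devices();
    }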


LLama/Native/SafeLLamaContextHandle.cs (+36 -20)

@@ -1,5 +1,6 @@
 using System;
 using System.Buffers;
+using System.Runtime.InteropServices;
 using System.Text;
 using LLama.Exceptions;

@@ -8,6 +9,7 @@ namespace LLama.Native
     /// <summary>
     /// A safe wrapper around a llama_context
     /// </summary>
+    // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API)
     public sealed class SafeLLamaContextHandle
         : SafeLLamaHandleBase
     {
@@ -36,26 +38,10 @@ namespace LLama.Native
         #endregion

         #region construction/destruction
-        /// <summary>
-        /// Create a new SafeLLamaContextHandle
-        /// </summary>
-        /// <param name="handle">pointer to an allocated llama_context</param>
-        /// <param name="model">the model which this context was created from</param>
-        public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
-            : base(handle)
-        {
-            // Increment the model reference count while this context exists
-            _model = model;
-            var success = false;
-            _model.DangerousAddRef(ref success);
-            if (!success)
-                throw new RuntimeError("Failed to increment model refcount");
-        }
-
         /// <inheritdoc />
         protected override bool ReleaseHandle()
         {
-            NativeApi.llama_free(DangerousGetHandle());
+            llama_free(handle);
             SetHandle(IntPtr.Zero);

             // Decrement refcount on model
@@ -84,12 +70,42 @@ namespace LLama.Native
         /// <exception cref="RuntimeError"></exception>
         public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
         {
-            var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
-            if (ctx_ptr == IntPtr.Zero)
+            var ctx = llama_new_context_with_model(model, lparams);
+            if (ctx == null)
                 throw new RuntimeError("Failed to create context from model");

-            return new(ctx_ptr, model);
+            // Increment the model reference count while this context exists.
+            // DangerousAddRef throws if it fails, so there is no need to check "success"
+            ctx._model = model;
+            var success = false;
+            ctx._model.DangerousAddRef(ref success);
+
+            return ctx;
         }
         #endregion

+        #region Native API
+        static SafeLLamaContextHandle()
+        {
+            // This ensures that `NativeApi` has been loaded before calling the two native methods below
+            NativeApi.llama_empty_call();
+        }
+
+        /// <summary>
+        /// Create a new llama_context with the given model. **This should never be called directly! Always use SafeLLamaContextHandle.Create**!
+        /// </summary>
+        /// <param name="model"></param>
+        /// <param name="params"></param>
+        /// <returns></returns>
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        private static extern SafeLLamaContextHandle llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);
+
+        /// <summary>
+        /// Frees all allocated memory in the given llama_context
+        /// </summary>
+        /// <param name="ctx"></param>
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        private static extern void llama_free(IntPtr ctx);
+        #endregion

         /// <summary>
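
Deleting the public (IntPtr, SafeLlamaModelHandle) constructor works because the P/Invoke above returns the SafeHandle subclass directly: the interop marshaler constructs the instance itself through a parameterless constructor and stores the raw pointer via SetHandle, which is also why the class carries the "used implicitly in native API" annotation. A minimal standalone sketch of the requirements (hypothetical type, not the LLamaSharp one):

    using System;
    using System.Runtime.InteropServices;

    class ExampleContextHandle : SafeHandle
    {
        // Invoked by the interop marshaler when this type is a P/Invoke return type
        public ExampleContextHandle()
            : base(IntPtr.Zero, true)
        {
        }

        public override bool IsInvalid => handle == IntPtr.Zero;

        protected override bool ReleaseHandle()
        {
            // Free the native resource here; must not throw
            return true;
        }
    }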


LLama/Native/SafeLLamaHandleBase.cs (+1 -1)

@@ -31,6 +31,6 @@ namespace LLama.Native

         /// <inheritdoc />
         public override string ToString()
-            => $"0x{handle.ToString("x16")}";
+            => $"0x{handle:x16}";
    }
}

LLama/Native/SafeLlamaModelHandle.cs (+30 -5)

@@ -1,6 +1,7 @@
 using System;
 using System.Collections.Generic;
 using System.Diagnostics;
+using System.Runtime.InteropServices;
 using System.Text;
 using LLama.Exceptions;
 using LLama.Extensions;
@@ -10,6 +11,7 @@ namespace LLama.Native
     /// <summary>
     /// A reference to a set of llama model weights
     /// </summary>
+    // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API)
     public sealed class SafeLlamaModelHandle
         : SafeLLamaHandleBase
     {
@@ -47,8 +49,7 @@ namespace LLama.Native
         /// <inheritdoc />
         protected override bool ReleaseHandle()
         {
-            NativeApi.llama_free_model(DangerousGetHandle());
-            SetHandle(IntPtr.Zero);
+            llama_free_model(handle);
             return true;
         }

@@ -61,13 +62,37 @@ namespace LLama.Native
         /// <exception cref="RuntimeError"></exception>
         public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
         {
-            var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
-            if (model_ptr == null)
+            var model = llama_load_model_from_file(modelPath, lparams);
+            if (model == null)
                 throw new RuntimeError($"Failed to load model {modelPath}.");

-            return model_ptr;
+            return model;
         }

+        #region native API
+        static SafeLlamaModelHandle()
+        {
+            // This ensures that `NativeApi` has been loaded before calling the two native methods below
+            NativeApi.llama_empty_call();
+        }
+
+        /// <summary>
+        /// Load all of the weights of a model into memory.
+        /// </summary>
+        /// <param name="path_model"></param>
+        /// <param name="params"></param>
+        /// <returns>The loaded model, or null on failure.</returns>
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        private static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);
+
+        /// <summary>
+        /// Frees all allocated memory associated with a model
+        /// </summary>
+        /// <param name="model"></param>
+        [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+        private static extern void llama_free_model(IntPtr model);
+        #endregion

         #region LoRA

         /// <summary>
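
The DangerousAddRef in SafeLLamaContextHandle.Create, paired with the DangerousRelease in the context's ReleaseHandle ("Decrement refcount on model" above), means the model weights cannot be freed while any context is alive, regardless of disposal order. A sketch of that guarantee, again assuming the params come from elsewhere:

    using LLama.Native;

    class LifetimeExample
    {
        static void Run(string path, LLamaModelParams mparams, LLamaContextParams cparams)
        {
            var model = SafeLlamaModelHandle.LoadFromFile(path, mparams);
            var ctx = SafeLLamaContextHandle.Create(model, cparams);

            model.Dispose(); // the context still holds a ref, so the weights survive

            ctx.Dispose();   // frees the context and releases the model reference;
                             // only now does llama_free_model actually run
        }
    }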

