diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index 40c5432b..fdc91152 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -93,7 +93,7 @@ namespace LLama.Abstractions
///
/// how split tensors should be distributed across GPUs
///
- nint TensorSplits { get; set; }
+ float[]? TensorSplits { get; set; }
///
/// Grouped-Query Attention
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 72c77937..5cb81078 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -1,14 +1,13 @@
using LLama.Abstractions;
using System;
-using System.Collections.Generic;
-using System.Text;
namespace LLama.Common
{
///
/// The parameters for initializing a LLama model.
///
- public class ModelParams : IModelParams
+ public class ModelParams
+ : IModelParams
{
///
/// Model context size (n_ctx)
@@ -85,7 +84,7 @@ namespace LLama.Common
///
/// how split tensors should be distributed across GPUs
///
- public nint TensorSplits { get; set; }
+ public float[]? TensorSplits { get; set; }
///
/// Grouped-Query Attention
diff --git a/LLama/Utils.cs b/LLama/Utils.cs
index 1454693f..7f05c1c7 100644
--- a/LLama/Utils.cs
+++ b/LLama/Utils.cs
@@ -15,8 +15,13 @@ namespace LLama
{
public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params)
{
- var lparams = NativeApi.llama_context_default_params();
+ if (!File.Exists(@params.ModelPath))
+ throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
+
+ if (@params.TensorSplits != null && @params.TensorSplits.Length != 1)
+ throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp.");
+ var lparams = NativeApi.llama_context_default_params();
lparams.n_ctx = @params.ContextSize;
lparams.n_batch = @params.BatchSize;
lparams.main_gpu = @params.MainGpu;
@@ -34,27 +39,21 @@ namespace LLama
lparams.rope_freq_scale = @params.RopeFrequencyScale;
lparams.mul_mat_q = @params.MulMatQ;
- /*
- if (@params.TensorSplits.Length != 1)
- {
- throw new ArgumentException("Currently multi-gpu support is not supported by " +
- "both llama.cpp and LLamaSharp.");
- }*/
-
- lparams.tensor_split = @params.TensorSplits;
-
- if (!File.Exists(@params.ModelPath))
+ unsafe
{
- throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}");
- }
+ fixed (float* splits = @params.TensorSplits)
+ {
+ lparams.tensor_split = (nint)splits;
- var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
- var ctx = SafeLLamaContextHandle.Create(model, lparams);
+ var model = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams);
+ var ctx = SafeLLamaContextHandle.Create(model, lparams);
- if (!string.IsNullOrEmpty(@params.LoraAdapter))
- model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
+ if (!string.IsNullOrEmpty(@params.LoraAdapter))
+ model.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads);
- return ctx;
+ return ctx;
+ }
+ }
}
public static IEnumerable Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding)