diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index ee2936d0..59e8869a 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -52,13 +52,13 @@ public class BatchedDecoding
return;
}
- var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
+ var batch = new LLamaBatch(1);
// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
- batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
+ batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
- if (context.NativeHandle.Decode(batch) != 0)
+ if (await context.DecodeAsync(batch) != 0)
{
await Console.Error.WriteLineAsync("llama_decode failed");
return;
@@ -97,7 +97,7 @@ public class BatchedDecoding
timer.Start();
while (n_cur <= n_len)
{
- batch.LLamaBatchClear();
+ batch.Clear();
for (var i = 0; i < n_parallel; i++)
{
@@ -129,7 +129,7 @@ public class BatchedDecoding
i_batch[i] = batch.TokenCount;
// push this new token for next evaluation
- batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+ batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
n_decode++;
}
@@ -143,7 +143,7 @@ public class BatchedDecoding
n_cur++;
// evaluate the current batch with the transformer model
- if (context.NativeHandle.Decode(batch) != 0)
+ if (await context.DecodeAsync(batch) != 0)
{
await Console.Error.WriteLineAsync("failed to eval");
return;
diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs
index 3014894e..99486088 100644
--- a/LLama.Unittest/BeamTests.cs
+++ b/LLama.Unittest/BeamTests.cs
@@ -40,7 +40,7 @@ public sealed class BeamTests
var initial_tokens = context.Tokenize(prompt);
result.Append(prompt);
- context.Eval(initial_tokens, 0);
+ context.Eval(initial_tokens.AsSpan(), 0);
NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
{
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 72e9acf8..8d4be20c 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -36,7 +36,7 @@ namespace LLama.Unittest
var executor = new StatelessExecutor(_weights, _params);
- const string question = "Question. what is a cat?\nAnswer: ";
+ const string question = "Question. what is a cat?\nAnswer:";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
var timer = new Stopwatch();
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index dd3d081a..ea745d02 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -8,10 +8,12 @@ using System.IO;
using System.IO.MemoryMappedFiles;
using LLama.Common;
using System.Runtime.InteropServices;
+using System.Threading.Tasks;
using LLama.Extensions;
using LLama.Abstractions;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
+using System.Threading;
namespace LLama
{
@@ -344,16 +346,30 @@ namespace LLama
#region eval overloads
///
- ///
///
- ///
- ///
- /// The updated `pastTokensCount`.
- ///
- [Obsolete("use llama_decode() instead")]
- public int Eval(LLamaToken[] tokens, int pastTokensCount)
+ ///
+ /// Positive return values does not mean a fatal error, but rather a warning:
+ /// - 0: success
+ /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ /// - < 0: error
+ ///
+ public int Decode(LLamaBatch batch)
+ {
+ return NativeHandle.Decode(batch);
+ }
+
+ ///
+ ///
+ ///
+ ///
+ /// Positive return values does not mean a fatal error, but rather a warning:
+ /// - 0: success
+ /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+ /// - < 0: error
+ ///
+ public Task DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
{
- return Eval(tokens.AsSpan(), pastTokensCount);
+ return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
}
///
@@ -363,7 +379,7 @@ namespace LLama
///
/// The updated `pastTokensCount`.
///
- [Obsolete("use llama_decode() instead")]
+ [Obsolete("use Decode() instead")]
public int Eval(List tokens, int pastTokensCount)
{
#if NET5_0_OR_GREATER
@@ -394,20 +410,7 @@ namespace LLama
///
/// The updated `pastTokensCount`.
///
- [Obsolete("use llama_decode() instead")]
- public int Eval(ReadOnlyMemory tokens, int pastTokensCount)
- {
- return Eval(tokens.Span, pastTokensCount);
- }
-
- ///
- ///
- ///
- ///
- ///
- /// The updated `pastTokensCount`.
- ///
- [Obsolete("use llama_decode() instead")]
+ [Obsolete("use Decode() instead")]
public int Eval(ReadOnlySpan tokens, int pastTokensCount)
{
var total = tokens.Length;
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 0c6cc87c..bccfd141 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -75,7 +75,7 @@ namespace LLama
// TODO(Rinne): deal with log of prompt
if (embed_inp_array.Length > 0)
- Context.Eval(embed_inp_array, 0);
+ Context.Eval(embed_inp_array.AsSpan(), 0);
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index e4dc6af6..9abb15ae 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -1,4 +1,5 @@
using System;
+using System.Collections.Generic;
namespace LLama.Native;
@@ -7,27 +8,42 @@ namespace LLama.Native;
///
public class LLamaBatch
{
- private readonly byte[] _logits;
+ private byte[] _logits;
- private readonly LLamaToken[] _tokens;
- private readonly LLamaPos[] _positions;
+ private LLamaToken[] _tokens;
+ private LLamaPos[] _positions;
- private readonly int[] _sequenceIdCount;
- private readonly LLamaSeqId[][] _sequenceIds;
- private readonly IntPtr[] _sequenceIdsPtrs;
+ private int[] _sequenceIdCount;
+ private LLamaSeqId[][] _sequenceIds;
+ private IntPtr[] _sequenceIdsPtrs;
///
/// The number of tokens in this batch
///
public int TokenCount { get; private set; }
+ ///
+ /// Maximum number of tokens that can be added to this batch
+ ///
+ private int TokenCapacity { get; set; }
+
+ ///
+ /// Maximum number of sequences a token can be assigned to
+ ///
+ public int MaxSequences { get; private set; }
+
///
/// Create a new batch for submitting inputs to llama.cpp
///
- ///
- ///
- public LLamaBatch(int n_tokens, int n_seq_max)
+ /// Max number of sequences a token can be assigned to
+ public LLamaBatch(int n_seq_max)
{
+ // The number of tokens can be grown later, start off with a reasonable guess.
+ const int n_tokens = 64;
+
+ MaxSequences = n_seq_max;
+ TokenCapacity = n_tokens;
+
_logits = new byte[n_tokens];
_tokens = new LLamaToken[n_tokens];
_positions = new LLamaPos[n_tokens];
@@ -37,7 +53,29 @@ public class LLamaBatch
_sequenceIds = new LLamaSeqId[n_tokens][];
for (var i = 0; i < _sequenceIds.Length; i++)
- _sequenceIds[i] = new LLamaSeqId[n_seq_max];
+ _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+ }
+
+ private void Grow()
+ {
+ var n_tokens = TokenCount * 2;
+ TokenCapacity = n_tokens;
+
+ Array.Resize(ref _logits, n_tokens);
+ Array.Resize(ref _tokens, n_tokens);
+ Array.Resize(ref _positions, n_tokens);
+
+ Array.Resize(ref _sequenceIdCount, n_tokens);
+ Array.Resize(ref _sequenceIdsPtrs, n_tokens);
+
+ Array.Resize(ref _sequenceIds, n_tokens);
+ for (int i = 0; i < _sequenceIds.Length; i++)
+ {
+ // Growing the array filled elements with null, temporarily violating the nullability contract!
+ // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
+ if (_sequenceIds[i] == null)
+ _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+ }
}
internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
@@ -79,8 +117,11 @@ public class LLamaBatch
/// The position to add it att
/// The set of sequences to add this token to
///
- public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan sequences, bool logits)
+ public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan sequences, bool logits)
{
+ if (TokenCount == TokenCapacity)
+ Grow();
+
_tokens[TokenCount] = token;
_positions[TokenCount] = pos;
@@ -101,20 +142,20 @@ public class LLamaBatch
/// The position to add it att
/// The sequence to add this token to
///
- public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+ public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
{
// Create a temporary span to contain 1 item without allocating
Span sequences = stackalloc LLamaSeqId[1];
sequences[0] = sequence;
// Add it
- LLamaBatchAdd(token, pos, sequences, logits);
+ Add(token, pos, sequences, logits);
}
///
/// Set TokenCount to zero for this batch
///
- public void LLamaBatchClear()
+ public void Clear()
{
TokenCount = 0;
}