diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs
index 3ba3b1cd..383428af 100644
--- a/LLama.Unittest/TokenTests.cs
+++ b/LLama.Unittest/TokenTests.cs
@@ -79,7 +79,7 @@ public sealed class TokenTests
var strings = new[]
{
"Hello world",
- "철수라는",
+ "철수",
"😀 😃 😄 😁 😆 😅 😂 😊 😇 🙂 ",
};
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index a190c075..7edf62c4 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -102,13 +102,9 @@ namespace LLama
///
///
///
- public string DeTokenize(IEnumerable tokens)
+ public string DeTokenize(IReadOnlyList tokens)
{
- var sb = new StringBuilder();
- foreach (var token in tokens)
- NativeHandle.TokenToString(token, Encoding, sb);
-
- return sb.ToString();
+ return NativeHandle.DeTokenize(tokens, Encoding);
}
///
@@ -418,26 +414,6 @@ namespace LLama
}
#endregion
- ///
- /// Convert a token into a string
- ///
- ///
- ///
- public string TokenToString(llama_token token)
- {
- return NativeHandle.TokenToString(token, Encoding);
- }
-
- ///
- /// Append a single token to a string builder
- ///
- /// Token to decode
- /// string builder to append the result to
- public void TokenToString(llama_token token, StringBuilder dest)
- {
- NativeHandle.TokenToString(token, Encoding, dest);
- }
-
///
public void Dispose()
{
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 1a12c6b2..578bd4d8 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -294,10 +294,7 @@ namespace LLama
await InferInternal(inferenceParams, args);
if (args.ReturnValue)
- {
- foreach (var id in _embeds)
- yield return Context.TokenToString(id);
- }
+ yield return Context.DeTokenize(_embeds);
var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args);
if (extraOutputs is { Count: > 0 })
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 80488b71..457b3894 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -95,7 +95,7 @@ namespace LLama
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
lastTokens.Add(id);
- yield return Context.TokenToString(id);
+ yield return Context.DeTokenize(new [] { id }); //todo: not correct to return tokens one by one like this!
tokens.Clear();
tokens.Add(id);
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index c411385c..160b8cc8 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -1,5 +1,6 @@
using System;
using System.Buffers;
+using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
@@ -159,38 +160,43 @@ namespace LLama.Native
}
///
- /// Convert a token into a string
+ /// Convert a single llama token into bytes
///
- /// Token to decode into a string
- ///
- ///
- public string TokenToString(int token, Encoding encoding)
+ /// Token to decode
+ /// A span to attempt to write into. If this is too small nothing will be written
+ /// The size of this token. **nothing will be written** if this is larger than `dest`
+ public int TokenToSpan(int token, Span dest)
{
- return ThrowIfDisposed().TokenToString(token, encoding);
+ return ThrowIfDisposed().TokenToSpan(token, dest);
}
///
- /// Append a single llama token to a string builder
+ /// Convert a set of tokens into a string
///
- /// Token to decode
+ ///
///
- /// string builder to append the result to
- public void TokenToString(int token, Encoding encoding, StringBuilder dest)
+ ///
+ public string DeTokenize(IReadOnlyList tokens, Encoding encoding)
{
- ThrowIfDisposed().TokenToString(token, encoding, dest);
- }
+ var chars = ArrayPool.Shared.Rent(tokens.Count * 2);
+ try
+ {
+ var span = ThrowIfDisposed().TokensToSpan(tokens, chars.AsSpan(), encoding);
+ if (span.Length == 0)
+ return "";
- ///
- /// Convert a single llama token into bytes
- ///
- /// Token to decode
- /// A span to attempt to write into. If this is too small nothing will be written
- /// The size of this token. **nothing will be written** if this is larger than `dest`
- public int TokenToSpan(int token, Span dest)
- {
- return ThrowIfDisposed().TokenToSpan(token, dest);
+ unsafe
+ {
+ fixed (char* ptr = &span[0])
+ return new string(ptr, 0, span.Length);
+ }
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(chars);
+ }
}
- #endregion
+#endregion
///
/// Run the llama inference to obtain the logits and probabilities for the next token.
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 79beed61..7afcc3af 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -2,6 +2,7 @@
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
+using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
using LLama.Extensions;
@@ -118,66 +119,6 @@ namespace LLama.Native
}
}
- ///
- /// Convert a single llama token into a string
- ///
- ///
- /// Encoding to use to decode the bytes into a string
- ///
- public string TokenToString(int llama_token, Encoding encoding)
- {
- unsafe
- {
- var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
- if (length == 0)
- return "";
-
- Span bytes = stackalloc byte[-length];
-
- fixed (byte* bytePtr = bytes)
- {
- var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
- Debug.Assert(written == bytes.Length);
-
- return encoding.GetString(bytePtr, bytes.Length);
- }
- }
- }
-
- ///
- /// Append a single llama token to a string builder
- ///
- /// Token to decode
- ///
- /// string builder to append the result to
- public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest)
- {
- unsafe
- {
- var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0);
- if (length == 0)
- return;
-
- Span bytes = stackalloc byte[-length];
- fixed (byte* bytePtr = bytes)
- {
- // Decode into bytes
- var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length);
- Debug.Assert(written == bytes.Length);
-
- // Decode into chars
- var charCount = encoding.GetCharCount(bytePtr, bytes.Length);
- Span chars = stackalloc char[charCount];
- fixed (char* charPtr = chars)
- encoding.GetChars(bytePtr, bytes.Length, charPtr, chars.Length);
-
- // Write it to the output
- for (var i = 0; i < chars.Length; i++)
- dest.Append(chars[i]);
- }
- }
- }
-
///
/// Convert a sequence of tokens into characters.
///
@@ -192,42 +133,52 @@ namespace LLama.Native
{
// Rent an array to detokenize into
var tokenBytesArr = ArrayPool.Shared.Rent(16);
- var tokenCharsArr = ArrayPool.Shared.Rent(16);
- try
+
+ // Convert all of the tokens into bytes
+ var bytes = new List();
+ foreach (var token in tokens)
{
- var totalCharacters = 0;
- var unused = dest;
+ var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this);
+ foreach (var tokenByte in tokenBytes)
+ bytes.Add(tokenByte);
+ }
- for (var i = tokens.Count - 1; i >= 0; i--)
+ // Extract a span from the list
+ var bytesSpan =
+#if NETSTANDARD2_0
+ bytes.ToArray().AsSpan();
+#else
+ CollectionsMarshal.AsSpan(bytes);
+#endif
+
+ // Check how many characters these bytes represent. If there's not enough space in the
+ // output array we need to handle that.
+ var characterCount = encoding.GetCharCount(bytesSpan);
+ if (characterCount > dest.Length)
+ {
+ var bigChars = ArrayPool.Shared.Rent(characterCount);
+ try
{
- var token = tokens[i];
-
- // Get bytes for this token
- var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this);
-
- // Get chars for this token
- var tokenChars = BytesToChars(ref tokenCharsArr, tokenBytes, encoding);
+ encoding.GetChars(bytesSpan, bigChars);
+ var charSlice = bigChars
+ .AsSpan(0, characterCount)
+ .Slice(characterCount - dest.Length);
- // Trim down number of characters if there are too many
- if (tokenChars.Length > unused.Length)
- tokenChars = tokenChars.Slice(tokenChars.Length - unused.Length, unused.Length);
-
- // Copy characters
- tokenChars.CopyTo(unused.Slice(unused.Length - tokenChars.Length, tokenChars.Length));
- unused = unused.Slice(0, unused.Length - tokenChars.Length);
- totalCharacters += tokenChars.Length;
-
- // Break out if we've run out of space
- if (unused.Length == 0)
- break;
+ charSlice.CopyTo(dest);
+ return dest;
+ }
+ finally
+ {
+ ArrayPool.Shared.Return(bigChars);
}
- return dest.Slice(dest.Length - totalCharacters, totalCharacters);
+ //todo: handle dest span too small
+ throw new NotImplementedException();
}
- finally
+ else
{
- ArrayPool.Shared.Return(tokenBytesArr);
- ArrayPool.Shared.Return(tokenCharsArr);
+ var charCount = encoding.GetChars(bytes.ToArray(), dest);
+ return dest.Slice(0, charCount);
}
// vvv Local Functions vvv
@@ -250,19 +201,6 @@ namespace LLama.Native
Debug.Assert(l >= 0);
return new Span(bytes, 0, l);
}
-
- static Span BytesToChars(ref char[] chars, ReadOnlySpan bytes, Encoding encoding)
- {
- var count = encoding.GetCharCount(bytes);
- if (count > chars.Length)
- {
- ArrayPool.Shared.Return(chars);
- chars = ArrayPool.Shared.Rent(count * 2);
- }
-
- encoding.GetChars(bytes, chars);
- return chars.AsSpan(0, count);
- }
}
///
@@ -304,7 +242,7 @@ namespace LLama.Native
}
}
}
- #endregion
+#endregion
#region context
///