From efdf3d630ce1ef010c69c4337f75f9a3bef4fdd9 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Sun, 22 Oct 2023 21:43:36 +0100 Subject: [PATCH] - Removed all `TokenToString` methods (it's never correct to use them, because sometimes one single character may be represented by multiple tokens). - Built a new (hacky) `Detokenize` method which handles this --- LLama.Unittest/TokenTests.cs | 2 +- LLama/LLamaContext.cs | 28 +---- LLama/LLamaExecutorBase.cs | 5 +- LLama/LLamaStatelessExecutor.cs | 2 +- LLama/Native/SafeLLamaContextHandle.cs | 50 +++++---- LLama/Native/SafeLlamaModelHandle.cs | 142 +++++++------------------ 6 files changed, 73 insertions(+), 156 deletions(-) diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs index 3ba3b1cd..383428af 100644 --- a/LLama.Unittest/TokenTests.cs +++ b/LLama.Unittest/TokenTests.cs @@ -79,7 +79,7 @@ public sealed class TokenTests var strings = new[] { "Hello world", - "철수라는", + "철수", "😀 😃 😄 😁 😆 😅 😂 😊 😇 🙂 ", }; diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index a190c075..7edf62c4 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -102,13 +102,9 @@ namespace LLama /// /// /// - public string DeTokenize(IEnumerable tokens) + public string DeTokenize(IReadOnlyList tokens) { - var sb = new StringBuilder(); - foreach (var token in tokens) - NativeHandle.TokenToString(token, Encoding, sb); - - return sb.ToString(); + return NativeHandle.DeTokenize(tokens, Encoding); } /// @@ -418,26 +414,6 @@ namespace LLama } #endregion - /// - /// Convert a token into a string - /// - /// - /// - public string TokenToString(llama_token token) - { - return NativeHandle.TokenToString(token, Encoding); - } - - /// - /// Append a single token to a string builder - /// - /// Token to decode - /// string builder to append the result to - public void TokenToString(llama_token token, StringBuilder dest) - { - NativeHandle.TokenToString(token, Encoding, dest); - } - /// public void Dispose() { diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index 1a12c6b2..578bd4d8 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -294,10 +294,7 @@ namespace LLama await InferInternal(inferenceParams, args); if (args.ReturnValue) - { - foreach (var id in _embeds) - yield return Context.TokenToString(id); - } + yield return Context.DeTokenize(_embeds); var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args); if (extraOutputs is { Count: > 0 }) diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 80488b71..457b3894 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -95,7 +95,7 @@ namespace LLama inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); lastTokens.Add(id); - yield return Context.TokenToString(id); + yield return Context.DeTokenize(new [] { id }); //todo: not correct to return tokens one by one like this! tokens.Clear(); tokens.Add(id); diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index c411385c..160b8cc8 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -1,5 +1,6 @@ using System; using System.Buffers; +using System.Collections.Generic; using System.Text; using LLama.Exceptions; @@ -159,38 +160,43 @@ namespace LLama.Native } /// - /// Convert a token into a string + /// Convert a single llama token into bytes /// - /// Token to decode into a string - /// - /// - public string TokenToString(int token, Encoding encoding) + /// Token to decode + /// A span to attempt to write into. If this is too small nothing will be written + /// The size of this token. **nothing will be written** if this is larger than `dest` + public int TokenToSpan(int token, Span dest) { - return ThrowIfDisposed().TokenToString(token, encoding); + return ThrowIfDisposed().TokenToSpan(token, dest); } /// - /// Append a single llama token to a string builder + /// Convert a set of tokens into a string /// - /// Token to decode + /// /// - /// string builder to append the result to - public void TokenToString(int token, Encoding encoding, StringBuilder dest) + /// + public string DeTokenize(IReadOnlyList tokens, Encoding encoding) { - ThrowIfDisposed().TokenToString(token, encoding, dest); - } + var chars = ArrayPool.Shared.Rent(tokens.Count * 2); + try + { + var span = ThrowIfDisposed().TokensToSpan(tokens, chars.AsSpan(), encoding); + if (span.Length == 0) + return ""; - /// - /// Convert a single llama token into bytes - /// - /// Token to decode - /// A span to attempt to write into. If this is too small nothing will be written - /// The size of this token. **nothing will be written** if this is larger than `dest` - public int TokenToSpan(int token, Span dest) - { - return ThrowIfDisposed().TokenToSpan(token, dest); + unsafe + { + fixed (char* ptr = &span[0]) + return new string(ptr, 0, span.Length); + } + } + finally + { + ArrayPool.Shared.Return(chars); + } } - #endregion +#endregion /// /// Run the llama inference to obtain the logits and probabilities for the next token. diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 79beed61..7afcc3af 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -2,6 +2,7 @@ using System.Buffers; using System.Collections.Generic; using System.Diagnostics; +using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; using LLama.Extensions; @@ -118,66 +119,6 @@ namespace LLama.Native } } - /// - /// Convert a single llama token into a string - /// - /// - /// Encoding to use to decode the bytes into a string - /// - public string TokenToString(int llama_token, Encoding encoding) - { - unsafe - { - var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0); - if (length == 0) - return ""; - - Span bytes = stackalloc byte[-length]; - - fixed (byte* bytePtr = bytes) - { - var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length); - Debug.Assert(written == bytes.Length); - - return encoding.GetString(bytePtr, bytes.Length); - } - } - } - - /// - /// Append a single llama token to a string builder - /// - /// Token to decode - /// - /// string builder to append the result to - public void TokenToString(int llama_token, Encoding encoding, StringBuilder dest) - { - unsafe - { - var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0); - if (length == 0) - return; - - Span bytes = stackalloc byte[-length]; - fixed (byte* bytePtr = bytes) - { - // Decode into bytes - var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length); - Debug.Assert(written == bytes.Length); - - // Decode into chars - var charCount = encoding.GetCharCount(bytePtr, bytes.Length); - Span chars = stackalloc char[charCount]; - fixed (char* charPtr = chars) - encoding.GetChars(bytePtr, bytes.Length, charPtr, chars.Length); - - // Write it to the output - for (var i = 0; i < chars.Length; i++) - dest.Append(chars[i]); - } - } - } - /// /// Convert a sequence of tokens into characters. /// @@ -192,42 +133,52 @@ namespace LLama.Native { // Rent an array to detokenize into var tokenBytesArr = ArrayPool.Shared.Rent(16); - var tokenCharsArr = ArrayPool.Shared.Rent(16); - try + + // Convert all of the tokens into bytes + var bytes = new List(); + foreach (var token in tokens) { - var totalCharacters = 0; - var unused = dest; + var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this); + foreach (var tokenByte in tokenBytes) + bytes.Add(tokenByte); + } - for (var i = tokens.Count - 1; i >= 0; i--) + // Extract a span from the list + var bytesSpan = +#if NETSTANDARD2_0 + bytes.ToArray().AsSpan(); +#else + CollectionsMarshal.AsSpan(bytes); +#endif + + // Check how many characters these bytes represent. If there's not enough space in the + // output array we need to handle that. + var characterCount = encoding.GetCharCount(bytesSpan); + if (characterCount > dest.Length) + { + var bigChars = ArrayPool.Shared.Rent(characterCount); + try { - var token = tokens[i]; - - // Get bytes for this token - var tokenBytes = TokenToBytes(ref tokenBytesArr, token, this); - - // Get chars for this token - var tokenChars = BytesToChars(ref tokenCharsArr, tokenBytes, encoding); + encoding.GetChars(bytesSpan, bigChars); + var charSlice = bigChars + .AsSpan(0, characterCount) + .Slice(characterCount - dest.Length); - // Trim down number of characters if there are too many - if (tokenChars.Length > unused.Length) - tokenChars = tokenChars.Slice(tokenChars.Length - unused.Length, unused.Length); - - // Copy characters - tokenChars.CopyTo(unused.Slice(unused.Length - tokenChars.Length, tokenChars.Length)); - unused = unused.Slice(0, unused.Length - tokenChars.Length); - totalCharacters += tokenChars.Length; - - // Break out if we've run out of space - if (unused.Length == 0) - break; + charSlice.CopyTo(dest); + return dest; + } + finally + { + ArrayPool.Shared.Return(bigChars); } - return dest.Slice(dest.Length - totalCharacters, totalCharacters); + //todo: handle dest span too small + throw new NotImplementedException(); } - finally + else { - ArrayPool.Shared.Return(tokenBytesArr); - ArrayPool.Shared.Return(tokenCharsArr); + var charCount = encoding.GetChars(bytes.ToArray(), dest); + return dest.Slice(0, charCount); } // vvv Local Functions vvv @@ -250,19 +201,6 @@ namespace LLama.Native Debug.Assert(l >= 0); return new Span(bytes, 0, l); } - - static Span BytesToChars(ref char[] chars, ReadOnlySpan bytes, Encoding encoding) - { - var count = encoding.GetCharCount(bytes); - if (count > chars.Length) - { - ArrayPool.Shared.Return(chars); - chars = ArrayPool.Shared.Rent(count * 2); - } - - encoding.GetChars(bytes, chars); - return chars.AsSpan(0, count); - } } /// @@ -304,7 +242,7 @@ namespace LLama.Native } } } - #endregion +#endregion #region context ///