diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index c589a270..832f3fdd 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -10,7 +10,7 @@ namespace LLama.Unittest
public BasicTest()
{
- _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+ _params = new ModelParams(Constants.ModelPath)
{
ContextSize = 2048
};
diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs
new file mode 100644
index 00000000..ea054bc8
--- /dev/null
+++ b/LLama.Unittest/Constants.cs
@@ -0,0 +1,7 @@
+namespace LLama.Unittest
+{
+ internal static class Constants
+ {
+ public static string ModelPath = "Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin";
+ }
+}
diff --git a/LLama.Unittest/GrammarTest.cs b/LLama.Unittest/GrammarTest.cs
index 482268ea..05791527 100644
--- a/LLama.Unittest/GrammarTest.cs
+++ b/LLama.Unittest/GrammarTest.cs
@@ -11,7 +11,7 @@ namespace LLama.Unittest
public GrammarTest()
{
- _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+ _params = new ModelParams(Constants.ModelPath)
{
ContextSize = 2048,
};
diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
index c6d2dc21..e9c84eac 100644
--- a/LLama.Unittest/LLamaContextTests.cs
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -10,7 +10,7 @@ namespace LLama.Unittest
public LLamaContextTests()
{
- var @params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+ var @params = new ModelParams(Constants.ModelPath)
{
ContextSize = 768,
};
diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
index 03487353..f94c90ba 100644
--- a/LLama.Unittest/LLamaEmbedderTests.cs
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -5,7 +5,7 @@ namespace LLama.Unittest;
public class LLamaEmbedderTests
: IDisposable
{
- private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin"));
+ private readonly LLamaEmbedder _embedder = new(new ModelParams(Constants.ModelPath));
public void Dispose()
{
@@ -36,18 +36,19 @@ public class LLamaEmbedderTests
Assert.Equal(expected[i], actual[i], epsilon);
}
- [Fact]
- public void EmbedBasic()
- {
- var cat = _embedder.GetEmbeddings("cat");
+ // todo: enable this one llama2 7B gguf is available
+ //[Fact]
+ //public void EmbedBasic()
+ //{
+ // var cat = _embedder.GetEmbeddings("cat");
- Assert.NotNull(cat);
- Assert.NotEmpty(cat);
+ // Assert.NotNull(cat);
+ // Assert.NotEmpty(cat);
- // Expected value generate with llama.cpp embedding.exe
- var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
- AssertApproxStartsWith(expected, cat);
- }
+ // // Expected value generate with llama.cpp embedding.exe
+ // var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
+ // AssertApproxStartsWith(expected, cat);
+ //}
[Fact]
public void EmbedCompare()
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 37031da3..1748e02d 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -13,7 +13,7 @@ namespace LLama.Unittest
public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
- _params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+ _params = new ModelParams(Constants.ModelPath)
{
ContextSize = 60,
Seed = 1754
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index fbb2107c..3f27d1ba 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -365,7 +365,7 @@ namespace LLama
}
// Save the newline logit value
- var nl_token = NativeApi.llama_token_nl();
+ var nl_token = NativeApi.llama_token_nl(_ctx);
var nl_logit = logits[nl_token];
// Convert logits into token candidates
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 1a84ad2f..bcbc2998 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -162,7 +162,7 @@ namespace LLama
}
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
+ if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
{
args.WaitForInput = true;
}
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 595ddb3b..aa184ca3 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -154,7 +154,7 @@ namespace LLama
}
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
+ if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
{
extraOutputs = new[] { " [end of text]\n" };
return true;
@@ -215,7 +215,7 @@ namespace LLama
_last_n_tokens.Enqueue(id);
- if (id == NativeApi.llama_token_eos())
+ if (id == NativeApi.llama_token_eos(Context.NativeHandle))
{
id = _llama_token_newline.First();
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 01b6a43d..c5c8d786 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -340,25 +340,25 @@ namespace LLama.Native
public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token);
///
- /// Get the "Beginning of string" token
+ /// Get the "Beginning of sentence" token
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_bos();
+ public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx);
///
- /// Get the "End of string" token
+ /// Get the "End of sentence" token
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_eos();
+ public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx);
///
/// Get the "new line" token
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_nl();
+ public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx);
///
/// Print out timing information for this context
@@ -410,9 +410,11 @@ namespace LLama.Native
///
///
///
- ///
+ /// buffer to write string into
+ /// size of the buffer
+ /// The length writte, or if the buffer is too small a negative that indicates the length required
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);
+ public static extern int llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
///
/// Convert text into tokens
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index 130e7c85..059bb070 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -1,4 +1,5 @@
using System;
+using System.Diagnostics;
using System.Text;
using LLama.Exceptions;
@@ -90,9 +91,16 @@ namespace LLama.Native
{
unsafe
{
- var bytes = new ReadOnlySpan(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
- var terminator = bytes.IndexOf((byte)0);
- return bytes.Slice(0, terminator);
+ var length = NativeApi.llama_token_to_str_with_model(this, llama_token, null, 0);
+ var bytes = new byte[-length];
+
+ fixed (byte* bytePtr = bytes)
+ {
+ var written = NativeApi.llama_token_to_str_with_model(this, llama_token, bytePtr, bytes.Length);
+ Debug.Assert(written == bytes.Length);
+ }
+
+ return new ReadOnlySpan(bytes);
}
}
@@ -104,16 +112,20 @@ namespace LLama.Native
///
public string TokenToString(int llama_token, Encoding encoding)
{
- var span = TokenToSpan(llama_token);
-
- if (span.Length == 0)
- return "";
-
unsafe
{
- fixed (byte* ptr = &span[0])
+ var length = NativeApi.llama_token_to_str_with_model(this, llama_token, null, 0);
+ if (length == 0)
+ return "";
+
+ Span bytes = stackalloc byte[-length];
+
+ fixed (byte* bytePtr = bytes)
{
- return encoding.GetString(ptr, span.Length);
+ var written = NativeApi.llama_token_to_str_with_model(this, llama_token, bytePtr, bytes.Length);
+ Debug.Assert(written == bytes.Length);
+
+ return encoding.GetString(bytePtr, bytes.Length);
}
}
}
diff --git a/LLama/OldVersion/LLamaModel.cs b/LLama/OldVersion/LLamaModel.cs
index ec528ec4..523b9553 100644
--- a/LLama/OldVersion/LLamaModel.cs
+++ b/LLama/OldVersion/LLamaModel.cs
@@ -634,7 +634,7 @@ namespace LLama.OldVersion
LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);
// Apply penalties
- float nl_logit = logits[NativeApi.llama_token_nl()];
+ float nl_logit = logits[NativeApi.llama_token_nl(_ctx)];
var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx);
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
_last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(),
@@ -644,7 +644,7 @@ namespace LLama.OldVersion
(ulong)last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
- logits[NativeApi.llama_token_nl()] = nl_logit;
+ logits[NativeApi.llama_token_nl(_ctx)] = nl_logit;
}
if (temp <= 0)
@@ -684,7 +684,7 @@ namespace LLama.OldVersion
}
// replace end of text token with newline token when in interactive mode
- if (id == NativeApi.llama_token_eos() && _params.interactive && !_params.instruct)
+ if (id == NativeApi.llama_token_eos(_ctx) && _params.interactive && !_params.instruct)
{
id = _llama_token_newline[0];
if (_params.antiprompt.Count != 0)
@@ -760,7 +760,7 @@ namespace LLama.OldVersion
break;
}
- if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos())
+ if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos(_ctx))
{
if (_params.instruct)
{