Browse Source

Passing ctx to `llama_token_nl(_ctx)`

tags/v0.5.1
Martin Evans 2 years ago
parent
commit
0c98ae1955
12 changed files with 62 additions and 40 deletions
  1. +1
    -1
      LLama.Unittest/BasicTest.cs
  2. +7
    -0
      LLama.Unittest/Constants.cs
  3. +1
    -1
      LLama.Unittest/GrammarTest.cs
  4. +1
    -1
      LLama.Unittest/LLamaContextTests.cs
  5. +12
    -11
      LLama.Unittest/LLamaEmbedderTests.cs
  6. +1
    -1
      LLama.Unittest/StatelessExecutorTest.cs
  7. +1
    -1
      LLama/LLamaContext.cs
  8. +1
    -1
      LLama/LLamaInstructExecutor.cs
  9. +2
    -2
      LLama/LLamaInteractExecutor.cs
  10. +9
    -7
      LLama/Native/NativeApi.cs
  11. +22
    -10
      LLama/Native/SafeLlamaModelHandle.cs
  12. +4
    -4
      LLama/OldVersion/LLamaModel.cs

+ 1
- 1
LLama.Unittest/BasicTest.cs View File

@@ -10,7 +10,7 @@ namespace LLama.Unittest

public BasicTest()
{
_params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
_params = new ModelParams(Constants.ModelPath)
{
ContextSize = 2048
};


+ 7
- 0
LLama.Unittest/Constants.cs View File

@@ -0,0 +1,7 @@
namespace LLama.Unittest
{
internal static class Constants
{
public static string ModelPath = "Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin";
}
}

+ 1
- 1
LLama.Unittest/GrammarTest.cs View File

@@ -11,7 +11,7 @@ namespace LLama.Unittest

public GrammarTest()
{
_params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
_params = new ModelParams(Constants.ModelPath)
{
ContextSize = 2048,
};


+ 1
- 1
LLama.Unittest/LLamaContextTests.cs View File

@@ -10,7 +10,7 @@ namespace LLama.Unittest

public LLamaContextTests()
{
var @params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
var @params = new ModelParams(Constants.ModelPath)
{
ContextSize = 768,
};


+ 12
- 11
LLama.Unittest/LLamaEmbedderTests.cs View File

@@ -5,7 +5,7 @@ namespace LLama.Unittest;
public class LLamaEmbedderTests
: IDisposable
{
private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin"));
private readonly LLamaEmbedder _embedder = new(new ModelParams(Constants.ModelPath));

public void Dispose()
{
@@ -36,18 +36,19 @@ public class LLamaEmbedderTests
Assert.Equal(expected[i], actual[i], epsilon);
}

[Fact]
public void EmbedBasic()
{
var cat = _embedder.GetEmbeddings("cat");
// todo: enable this one llama2 7B gguf is available
//[Fact]
//public void EmbedBasic()
//{
// var cat = _embedder.GetEmbeddings("cat");

Assert.NotNull(cat);
Assert.NotEmpty(cat);
// Assert.NotNull(cat);
// Assert.NotEmpty(cat);

// Expected value generate with llama.cpp embedding.exe
var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
AssertApproxStartsWith(expected, cat);
}
// // Expected value generate with llama.cpp embedding.exe
// var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
// AssertApproxStartsWith(expected, cat);
//}

[Fact]
public void EmbedCompare()


+ 1
- 1
LLama.Unittest/StatelessExecutorTest.cs View File

@@ -13,7 +13,7 @@ namespace LLama.Unittest
public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
_params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
_params = new ModelParams(Constants.ModelPath)
{
ContextSize = 60,
Seed = 1754


+ 1
- 1
LLama/LLamaContext.cs View File

@@ -365,7 +365,7 @@ namespace LLama
}

// Save the newline logit value
var nl_token = NativeApi.llama_token_nl();
var nl_token = NativeApi.llama_token_nl(_ctx);
var nl_logit = logits[nl_token];

// Convert logits into token candidates


+ 1
- 1
LLama/LLamaInstructExecutor.cs View File

@@ -162,7 +162,7 @@ namespace LLama
}
}

if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
{
args.WaitForInput = true;
}


+ 2
- 2
LLama/LLamaInteractExecutor.cs View File

@@ -154,7 +154,7 @@ namespace LLama
}
}

if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos())
if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
{
extraOutputs = new[] { " [end of text]\n" };
return true;
@@ -215,7 +215,7 @@ namespace LLama

_last_n_tokens.Enqueue(id);

if (id == NativeApi.llama_token_eos())
if (id == NativeApi.llama_token_eos(Context.NativeHandle))
{
id = _llama_token_newline.First();
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)


+ 9
- 7
LLama/Native/NativeApi.cs View File

@@ -340,25 +340,25 @@ namespace LLama.Native
public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token);

/// <summary>
/// Get the "Beginning of string" token
/// Get the "Beginning of sentence" token
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_token_bos();
public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx);

/// <summary>
/// Get the "End of string" token
/// Get the "End of sentence" token
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_token_eos();
public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx);

/// <summary>
/// Get the "new line" token
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern llama_token llama_token_nl();
public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx);

/// <summary>
/// Print out timing information for this context
@@ -410,9 +410,11 @@ namespace LLama.Native
/// </summary>
/// <param name="model"></param>
/// <param name="llamaToken"></param>
/// <returns></returns>
/// <param name="buffer">buffer to write string into</param>
/// <param name="length">size of the buffer</param>
/// <returns>The length writte, or if the buffer is too small a negative that indicates the length required</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);
public static extern int llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);

/// <summary>
/// Convert text into tokens


+ 22
- 10
LLama/Native/SafeLlamaModelHandle.cs View File

@@ -1,4 +1,5 @@
using System;
using System.Diagnostics;
using System.Text;
using LLama.Exceptions;

@@ -90,9 +91,16 @@ namespace LLama.Native
{
unsafe
{
var bytes = new ReadOnlySpan<byte>(NativeApi.llama_token_to_str_with_model(this, llama_token), int.MaxValue);
var terminator = bytes.IndexOf((byte)0);
return bytes.Slice(0, terminator);
var length = NativeApi.llama_token_to_str_with_model(this, llama_token, null, 0);
var bytes = new byte[-length];

fixed (byte* bytePtr = bytes)
{
var written = NativeApi.llama_token_to_str_with_model(this, llama_token, bytePtr, bytes.Length);
Debug.Assert(written == bytes.Length);
}

return new ReadOnlySpan<byte>(bytes);
}
}

@@ -104,16 +112,20 @@ namespace LLama.Native
/// <returns></returns>
public string TokenToString(int llama_token, Encoding encoding)
{
var span = TokenToSpan(llama_token);

if (span.Length == 0)
return "";

unsafe
{
fixed (byte* ptr = &span[0])
var length = NativeApi.llama_token_to_str_with_model(this, llama_token, null, 0);
if (length == 0)
return "";

Span<byte> bytes = stackalloc byte[-length];

fixed (byte* bytePtr = bytes)
{
return encoding.GetString(ptr, span.Length);
var written = NativeApi.llama_token_to_str_with_model(this, llama_token, bytePtr, bytes.Length);
Debug.Assert(written == bytes.Length);

return encoding.GetString(bytePtr, bytes.Length);
}
}
}


+ 4
- 4
LLama/OldVersion/LLamaModel.cs View File

@@ -634,7 +634,7 @@ namespace LLama.OldVersion
LLamaTokenDataArray candidates_p = new LLamaTokenDataArray(candidates);

// Apply penalties
float nl_logit = logits[NativeApi.llama_token_nl()];
float nl_logit = logits[NativeApi.llama_token_nl(_ctx)];
var last_n_repeat = Math.Min(Math.Min(_last_n_tokens.Count, repeat_last_n), _n_ctx);
SamplingApi.llama_sample_repetition_penalty(_ctx, candidates_p,
_last_n_tokens.Skip(_last_n_tokens.Count - last_n_repeat).ToArray(),
@@ -644,7 +644,7 @@ namespace LLama.OldVersion
(ulong)last_n_repeat, alpha_frequency, alpha_presence);
if (!penalize_nl)
{
logits[NativeApi.llama_token_nl()] = nl_logit;
logits[NativeApi.llama_token_nl(_ctx)] = nl_logit;
}

if (temp <= 0)
@@ -684,7 +684,7 @@ namespace LLama.OldVersion
}

// replace end of text token with newline token when in interactive mode
if (id == NativeApi.llama_token_eos() && _params.interactive && !_params.instruct)
if (id == NativeApi.llama_token_eos(_ctx) && _params.interactive && !_params.instruct)
{
id = _llama_token_newline[0];
if (_params.antiprompt.Count != 0)
@@ -760,7 +760,7 @@ namespace LLama.OldVersion
break;
}

if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos())
if (_embed.Count > 0 && _embed.Last() == NativeApi.llama_token_eos(_ctx))
{
if (_params.instruct)
{


Loading…
Cancel
Save