Browse Source

Merge pull request #311 from martindevans/Improved_Test_Coverage

Improved test coverage.
tags/v0.8.1
Martin Evans GitHub 2 years ago
parent
commit
c6d507040c
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 125 additions and 80 deletions
  1. +95
    -0
      LLama.Unittest/FixedSizeQueueTests.cs
  2. +19
    -2
      LLama.Unittest/GrammarTest.cs
  3. +1
    -21
      LLama.Unittest/LLamaEmbedderTests.cs
  4. +5
    -19
      LLama/Common/FixedSizeQueue.cs
  5. +1
    -1
      LLama/LLamaExecutorBase.cs
  6. +1
    -1
      LLama/LLamaInstructExecutor.cs
  7. +1
    -1
      LLama/LLamaInteractExecutor.cs
  8. +2
    -35
      LLama/Native/LLamaGrammarElement.cs

+ 95
- 0
LLama.Unittest/FixedSizeQueueTests.cs View File

@@ -0,0 +1,95 @@
using LLama.Common;

namespace LLama.Unittest;

public class FixedSizeQueueTests
{
[Fact]
public void Create()
{
var q = new FixedSizeQueue<int>(7);

Assert.Equal(7, q.Capacity);
Assert.Empty(q);
}

[Fact]
public void CreateFromItems()
{
var q = new FixedSizeQueue<int>(7, new [] { 1, 2, 3 });

Assert.Equal(7, q.Capacity);
Assert.Equal(3, q.Count);
Assert.True(q.ToArray().SequenceEqual(new[] { 1, 2, 3 }));
}

[Fact]
public void Indexing()
{
var q = new FixedSizeQueue<int>(7, new[] { 1, 2, 3 });

Assert.Equal(1, q[0]);
Assert.Equal(2, q[1]);
Assert.Equal(3, q[2]);

Assert.Throws<ArgumentOutOfRangeException>(() => q[3]);
}

[Fact]
public void CreateFromFullItems()
{
var q = new FixedSizeQueue<int>(3, new[] { 1, 2, 3 });

Assert.Equal(3, q.Capacity);
Assert.Equal(3, q.Count);
Assert.True(q.ToArray().SequenceEqual(new[] { 1, 2, 3 }));
}

[Fact]
public void CreateFromTooManyItems()
{
Assert.Throws<ArgumentException>(() => new FixedSizeQueue<int>(2, new[] { 1, 2, 3 }));
}

[Fact]
public void CreateFromTooManyItemsNonCountable()
{
Assert.Throws<ArgumentException>(() => new FixedSizeQueue<int>(2, Items()));
return;

static IEnumerable<int> Items()
{
yield return 1;
yield return 2;
yield return 3;
}
}

[Fact]
public void Enqueue()
{
var q = new FixedSizeQueue<int>(7, new[] { 1, 2, 3 });

q.Enqueue(4);
q.Enqueue(5);

Assert.Equal(7, q.Capacity);
Assert.Equal(5, q.Count);
Assert.True(q.ToArray().SequenceEqual(new[] { 1, 2, 3, 4, 5 }));
}

[Fact]
public void EnqueueOverflow()
{
var q = new FixedSizeQueue<int>(5, new[] { 1, 2, 3 });

q.Enqueue(4);
q.Enqueue(5);
q.Enqueue(6);
q.Enqueue(7);

Assert.Equal(5, q.Capacity);
Assert.Equal(5, q.Count);
Assert.True(q.ToArray().SequenceEqual(new[] { 3, 4, 5, 6, 7 }));
}
}

+ 19
- 2
LLama.Unittest/GrammarTest.cs View File

@@ -40,6 +40,22 @@ namespace LLama.Unittest
using var handle = SafeLLamaGrammarHandle.Create(rules, 0);
}

[Fact]
public void CreateGrammar_StartIndexOutOfRange()
{
var rules = new List<GrammarRule>
{
new GrammarRule("alpha", new[]
{
new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
}),
};

Assert.Throws<ArgumentOutOfRangeException>(() => new Grammar(rules, 3));
}

[Fact]
public async Task SampleWithTrivialGrammar()
{
@@ -56,14 +72,15 @@ namespace LLama.Unittest
}),
};

using var grammar = SafeLLamaGrammarHandle.Create(rules, 0);
var grammar = new Grammar(rules, 0);
using var grammarInstance = grammar.CreateInstance();

var executor = new StatelessExecutor(_model, _params);
var inferenceParams = new InferenceParams
{
MaxTokens = 3,
AntiPrompts = new [] { ".", "Input:", "\n" },
Grammar = grammar,
Grammar = grammarInstance,
};

var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync();


+ 1
- 21
LLama.Unittest/LLamaEmbedderTests.cs View File

@@ -2,7 +2,7 @@

namespace LLama.Unittest;

public class LLamaEmbedderTests
public sealed class LLamaEmbedderTests
: IDisposable
{
private readonly LLamaEmbedder _embedder;
@@ -37,26 +37,6 @@ public class LLamaEmbedderTests
return a.Zip(b, (x, y) => x * y).Sum();
}

private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f)
{
for (int i = 0; i < expected.Length; i++)
Assert.Equal(expected[i], actual[i], epsilon);
}

// todo: enable this one llama2 7B gguf is available
//[Fact]
//public void EmbedBasic()
//{
// var cat = _embedder.GetEmbeddings("cat");

// Assert.NotNull(cat);
// Assert.NotEmpty(cat);

// // Expected value generate with llama.cpp embedding.exe
// var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
// AssertApproxStartsWith(expected, cat);
//}

[Fact]
public void EmbedCompare()
{


+ 5
- 19
LLama/Common/FixedSizeQueue.cs View File

@@ -2,6 +2,7 @@
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using LLama.Extensions;

namespace LLama.Common
{
@@ -10,11 +11,12 @@ namespace LLama.Common
/// Currently it's only a naive implementation and needs to be further optimized in the future.
/// </summary>
public class FixedSizeQueue<T>
: IEnumerable<T>
: IReadOnlyList<T>
{
private readonly List<T> _storage;

internal IReadOnlyList<T> Items => _storage;
/// <inheritdoc />
public T this[int index] => _storage[index];

/// <summary>
/// Number of items in this queue
@@ -59,20 +61,6 @@ namespace LLama.Common
throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values.");
}

/// <summary>
/// Replace every item in the queue with the given value
/// </summary>
/// <param name="value">The value to replace all items with</param>
/// <returns>returns this</returns>
public FixedSizeQueue<T> FillWith(T value)
{
for(var i = 0; i < Count; i++)
{
_storage[i] = value;
}
return this;
}

/// <summary>
/// Enquene an element.
/// </summary>
@@ -80,10 +68,8 @@ namespace LLama.Common
public void Enqueue(T item)
{
_storage.Add(item);
if(_storage.Count >= Capacity)
{
if (_storage.Count > Capacity)
_storage.RemoveAt(0);
}
}

/// <inheritdoc />


+ 1
- 1
LLama/LLamaExecutorBase.cs View File

@@ -84,7 +84,7 @@ namespace LLama
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
_last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize).FillWith(0);
_last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize);
_decoder = new StreamingTokenDecoder(context);
}



+ 1
- 1
LLama/LLamaInstructExecutor.cs View File

@@ -151,7 +151,7 @@ namespace LLama
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
{
args.WaitForInput = true;
return (true, Array.Empty<string>());


+ 1
- 1
LLama/LLamaInteractExecutor.cs View File

@@ -134,7 +134,7 @@ namespace LLama
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)


+ 2
- 35
LLama/Native/LLamaGrammarElement.cs View File

@@ -1,5 +1,4 @@
using System;
using System.Diagnostics;
using System.Diagnostics;
using System.Runtime.InteropServices;

namespace LLama.Native
@@ -52,8 +51,7 @@ namespace LLama.Native
/// </summary>
[StructLayout(LayoutKind.Sequential)]
[DebuggerDisplay("{Type} {Value}")]
public struct LLamaGrammarElement
: IEquatable<LLamaGrammarElement>
public record struct LLamaGrammarElement
{
/// <summary>
/// The type of this element
@@ -76,37 +74,6 @@ namespace LLama.Native
Value = value;
}

/// <inheritdoc />
public bool Equals(LLamaGrammarElement other)
{
if (Type != other.Type)
return false;

// No need to compare values for the END rule
if (Type == LLamaGrammarElementType.END)
return true;

return Value == other.Value;
}

/// <inheritdoc />
public override bool Equals(object? obj)
{
return obj is LLamaGrammarElement other && Equals(other);
}

/// <inheritdoc />
public override int GetHashCode()
{
unchecked
{
var hash = 2999;
hash = hash * 7723 + (int)Type;
hash = hash * 7723 + (int)Value;
return hash;
}
}

internal bool IsCharElement()
{
switch (Type)


Loading…
Cancel
Save