Improved test coverage. tags/v0.8.1
| @@ -0,0 +1,95 @@ | |||
| using LLama.Common; | |||
| namespace LLama.Unittest; | |||
/// <summary>
/// Unit tests for <see cref="FixedSizeQueue{T}"/> covering construction, indexing and enqueueing.
/// </summary>
public class FixedSizeQueueTests
{
    /// <summary>
    /// A queue built from just a capacity reports that capacity and holds no items.
    /// </summary>
    [Fact]
    public void Create()
    {
        var q = new FixedSizeQueue<int>(7);

        Assert.Equal(7, q.Capacity);
        Assert.Empty(q);
    }

    /// <summary>
    /// A queue built from fewer items than its capacity keeps those items in order.
    /// </summary>
    [Fact]
    public void CreateFromItems()
    {
        var q = new FixedSizeQueue<int>(7, new[] { 1, 2, 3 });

        Assert.Equal(7, q.Capacity);
        Assert.Equal(3, q.Count);
        // Assert.Equal on the sequences reports the differing contents on failure,
        // which Assert.True(SequenceEqual(...)) cannot do.
        Assert.Equal(new[] { 1, 2, 3 }, q.ToArray());
    }

    /// <summary>
    /// The indexer returns items in insertion order and rejects an index equal to the item count.
    /// </summary>
    [Fact]
    public void Indexing()
    {
        var q = new FixedSizeQueue<int>(7, new[] { 1, 2, 3 });

        Assert.Equal(1, q[0]);
        Assert.Equal(2, q[1]);
        Assert.Equal(3, q[2]);
        Assert.Throws<ArgumentOutOfRangeException>(() => q[3]);
    }

    /// <summary>
    /// A queue may be built from exactly as many items as its capacity.
    /// </summary>
    [Fact]
    public void CreateFromFullItems()
    {
        var q = new FixedSizeQueue<int>(3, new[] { 1, 2, 3 });

        Assert.Equal(3, q.Capacity);
        Assert.Equal(3, q.Count);
        Assert.Equal(new[] { 1, 2, 3 }, q.ToArray());
    }

    /// <summary>
    /// Supplying more initial items (as an array) than the capacity allows is rejected.
    /// </summary>
    [Fact]
    public void CreateFromTooManyItems()
    {
        Assert.Throws<ArgumentException>(() => new FixedSizeQueue<int>(2, new[] { 1, 2, 3 }));
    }

    /// <summary>
    /// Supplying more initial items than the capacity allows is also rejected when the source is a
    /// lazily produced sequence rather than an array.
    /// </summary>
    [Fact]
    public void CreateFromTooManyItemsNonCountable()
    {
        Assert.Throws<ArgumentException>(() => new FixedSizeQueue<int>(2, Items()));
        return;

        // Iterator-based source: yields three items without being an array or list
        static IEnumerable<int> Items()
        {
            yield return 1;
            yield return 2;
            yield return 3;
        }
    }

    /// <summary>
    /// Enqueueing while below capacity appends the items and leaves the capacity unchanged.
    /// </summary>
    [Fact]
    public void Enqueue()
    {
        var q = new FixedSizeQueue<int>(7, new[] { 1, 2, 3 });

        q.Enqueue(4);
        q.Enqueue(5);

        Assert.Equal(7, q.Capacity);
        Assert.Equal(5, q.Count);
        Assert.Equal(new[] { 1, 2, 3, 4, 5 }, q.ToArray());
    }

    /// <summary>
    /// Enqueueing past capacity drops the oldest items so that only the newest
    /// <c>Capacity</c> items remain.
    /// </summary>
    [Fact]
    public void EnqueueOverflow()
    {
        var q = new FixedSizeQueue<int>(5, new[] { 1, 2, 3 });

        q.Enqueue(4);
        q.Enqueue(5);
        q.Enqueue(6);
        q.Enqueue(7);

        Assert.Equal(5, q.Capacity);
        Assert.Equal(5, q.Count);
        Assert.Equal(new[] { 3, 4, 5, 6, 7 }, q.ToArray());
    }
}
| @@ -40,6 +40,22 @@ namespace LLama.Unittest | |||
| using var handle = SafeLLamaGrammarHandle.Create(rules, 0); | |||
| } | |||
[Fact]
public void CreateGrammar_StartIndexOutOfRange()
{
    // Elements making up the single "alpha" rule
    var alphaElements = new[]
    {
        new LLamaGrammarElement(LLamaGrammarElementType.CHAR, 'a'),
        new LLamaGrammarElement(LLamaGrammarElementType.CHAR_RNG_UPPER, 'z'),
        new LLamaGrammarElement(LLamaGrammarElementType.END, 0),
    };

    // Only one rule is registered
    var ruleSet = new List<GrammarRule>();
    ruleSet.Add(new GrammarRule("alpha", alphaElements));

    // NOTE(review): the second argument is presumably the start rule index; 3 is past the
    // single rule above, and construction is expected to fail with ArgumentOutOfRangeException.
    Assert.Throws<ArgumentOutOfRangeException>(() => new Grammar(ruleSet, 3));
}
| [Fact] | |||
| public async Task SampleWithTrivialGrammar() | |||
| { | |||
| @@ -56,14 +72,15 @@ namespace LLama.Unittest | |||
| }), | |||
| }; | |||
| using var grammar = SafeLLamaGrammarHandle.Create(rules, 0); | |||
| var grammar = new Grammar(rules, 0); | |||
| using var grammarInstance = grammar.CreateInstance(); | |||
| var executor = new StatelessExecutor(_model, _params); | |||
| var inferenceParams = new InferenceParams | |||
| { | |||
| MaxTokens = 3, | |||
| AntiPrompts = new [] { ".", "Input:", "\n" }, | |||
| Grammar = grammar, | |||
| Grammar = grammarInstance, | |||
| }; | |||
| var result = await executor.InferAsync("Q. 7 + 12\nA. ", inferenceParams).ToListAsync(); | |||
| @@ -2,7 +2,7 @@ | |||
| namespace LLama.Unittest; | |||
| public class LLamaEmbedderTests | |||
| public sealed class LLamaEmbedderTests | |||
| : IDisposable | |||
| { | |||
| private readonly LLamaEmbedder _embedder; | |||
| @@ -37,26 +37,6 @@ public class LLamaEmbedderTests | |||
| return a.Zip(b, (x, y) => x * y).Sum(); | |||
| } | |||
/// <summary>
/// Assert that <paramref name="actual"/> begins with the values in <paramref name="expected"/>,
/// comparing element by element and passing <paramref name="epsilon"/> to <c>Assert.Equal</c>.
/// </summary>
/// <param name="expected">Values the start of <paramref name="actual"/> must match</param>
/// <param name="actual">Values to check; may be longer than <paramref name="expected"/></param>
/// <param name="epsilon">Third argument handed to each <c>Assert.Equal</c> comparison</param>
private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f)
{
    // Fail with a readable assertion instead of an IndexOutOfRangeException when `actual` is too short
    Assert.True(
        actual.Length >= expected.Length,
        $"Expected at least {expected.Length} values but got {actual.Length}");

    for (var i = 0; i < expected.Length; i++)
    {
        Assert.Equal(expected[i], actual[i], epsilon);
    }
}
| // todo: enable this once a llama2 7B gguf is available | |||
| //[Fact] | |||
| //public void EmbedBasic() | |||
| //{ | |||
| // var cat = _embedder.GetEmbeddings("cat"); | |||
| // Assert.NotNull(cat); | |||
| // Assert.NotEmpty(cat); | |||
| // // Expected values generated with llama.cpp embedding.exe | |||
| // var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f }; | |||
| // AssertApproxStartsWith(expected, cat); | |||
| //} | |||
| [Fact] | |||
| public void EmbedCompare() | |||
| { | |||
| @@ -2,6 +2,7 @@ | |||
| using System.Collections; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using LLama.Extensions; | |||
| namespace LLama.Common | |||
| { | |||
| @@ -10,11 +11,12 @@ namespace LLama.Common | |||
| /// Currently it's only a naive implementation and needs to be further optimized in the future. | |||
| /// </summary> | |||
| public class FixedSizeQueue<T> | |||
| : IEnumerable<T> | |||
| : IReadOnlyList<T> | |||
| { | |||
| private readonly List<T> _storage; | |||
| internal IReadOnlyList<T> Items => _storage; | |||
| /// <inheritdoc /> | |||
| public T this[int index] => _storage[index]; | |||
| /// <summary> | |||
| /// Number of items in this queue | |||
| @@ -59,20 +61,6 @@ namespace LLama.Common | |||
| throw new ArgumentException($"The max size set for the quene is {size}, but got {_storage.Count} initial values."); | |||
| } | |||
/// <summary>
/// Overwrite each item currently held in the queue with <paramref name="value"/>
/// </summary>
/// <param name="value">Value assigned to every existing slot</param>
/// <returns>This queue, so that calls can be chained</returns>
public FixedSizeQueue<T> FillWith(T value)
{
    var index = 0;
    while (index < Count)
    {
        _storage[index] = value;
        index++;
    }

    return this;
}
| /// <summary> | |||
| /// Enqueue an element. | |||
| /// </summary> | |||
| @@ -80,10 +68,8 @@ namespace LLama.Common | |||
public void Enqueue(T item)
{
    _storage.Add(item);

    // Drop the oldest item (index 0) only once the capacity has actually been exceeded;
    // a queue holding exactly Capacity items keeps all of them.
    if (_storage.Count > Capacity)
    {
        _storage.RemoveAt(0);
    }
}
| /// <inheritdoc /> | |||
| @@ -84,7 +84,7 @@ namespace LLama | |||
| _pastTokensCount = 0; | |||
| _consumedTokensCount = 0; | |||
| _n_session_consumed = 0; | |||
| _last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize).FillWith(0); | |||
| _last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize); | |||
| _decoder = new StreamingTokenDecoder(context); | |||
| } | |||
| @@ -151,7 +151,7 @@ namespace LLama | |||
| { | |||
| if (_embed_inps.Count <= _consumedTokensCount) | |||
| { | |||
| if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) | |||
| if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) | |||
| { | |||
| args.WaitForInput = true; | |||
| return (true, Array.Empty<string>()); | |||
| @@ -134,7 +134,7 @@ namespace LLama | |||
| { | |||
| if (_embed_inps.Count <= _consumedTokensCount) | |||
| { | |||
| if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) | |||
| if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding)) | |||
| args.WaitForInput = true; | |||
| if (_pastTokensCount > 0 && args.WaitForInput) | |||
| @@ -1,5 +1,4 @@ | |||
| using System; | |||
| using System.Diagnostics; | |||
| using System.Diagnostics; | |||
| using System.Runtime.InteropServices; | |||
| namespace LLama.Native | |||
| @@ -52,8 +51,7 @@ namespace LLama.Native | |||
| /// </summary> | |||
| [StructLayout(LayoutKind.Sequential)] | |||
| [DebuggerDisplay("{Type} {Value}")] | |||
| public struct LLamaGrammarElement | |||
| : IEquatable<LLamaGrammarElement> | |||
| public record struct LLamaGrammarElement | |||
| { | |||
| /// <summary> | |||
| /// The type of this element | |||
| @@ -76,37 +74,6 @@ namespace LLama.Native | |||
| Value = value; | |||
| } | |||
/// <inheritdoc />
public bool Equals(LLamaGrammarElement other)
{
    // Elements of different types are never equal
    var sameType = Type == other.Type;
    if (!sameType)
    {
        return false;
    }

    // END elements match regardless of Value; every other type must also agree on Value
    return Type == LLamaGrammarElementType.END || Value == other.Value;
}
/// <inheritdoc />
public override bool Equals(object? obj)
{
    // Anything that is not a LLamaGrammarElement (including null) cannot be equal
    if (!(obj is LLamaGrammarElement element))
    {
        return false;
    }

    // Defer to the strongly typed comparison
    return Equals(element);
}
/// <inheritdoc />
public override int GetHashCode()
{
    // Overflow is expected while mixing and must wrap silently
    unchecked
    {
        var hash = 2999;
        hash = hash * 7723 + (int)Type;

        // Equals treats all END elements as equal whatever their Value, so Value must be
        // left out of the hash for END; otherwise equal elements could hash differently.
        if (Type != LLamaGrammarElementType.END)
        {
            hash = hash * 7723 + (int)Value;
        }

        return hash;
    }
}
| internal bool IsCharElement() | |||
| { | |||
| switch (Type) | |||