diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index 5a07c86c..2a88f25d 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -1,4 +1,3 @@
-using LLama;
 using LLama.Common;
 
 namespace LLama.Unittest
diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
new file mode 100644
index 00000000..a34f58cb
--- /dev/null
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -0,0 +1,49 @@
+using System.Text;
+using LLama.Common;
+
+namespace LLama.Unittest
+{
+    /// <summary>
+    /// Checks the properties reported by a <see cref="LLamaContext"/> created from the test model.
+    /// </summary>
+    public class LLamaContextTests
+        : IDisposable
+    {
+        private readonly LLamaWeights _weights;
+        private readonly LLamaContext _context;
+
+        /// <summary>
+        /// Loads the model weights and creates a context with ContextSize set to 768.
+        /// </summary>
+        public LLamaContextTests()
+        {
+            var @params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+            {
+                ContextSize = 768,
+            };
+            _weights = LLamaWeights.LoadFromFile(@params);
+            _context = _weights.CreateContext(@params, Encoding.UTF8);
+        }
+
+        /// <summary>
+        /// Releases the context and then the weights it was created from.
+        /// </summary>
+        public void Dispose()
+        {
+            // Dispose in reverse order of creation: the context was created from the weights.
+            _context.Dispose();
+            _weights.Dispose();
+        }
+
+        /// <summary>
+        /// The context reports the requested context size and the expected embedding and vocabulary sizes.
+        /// </summary>
+        [Fact]
+        public void CheckProperties()
+        {
+            Assert.Equal(768, _context.ContextSize);
+            Assert.Equal(4096, _context.EmbeddingSize);
+            Assert.Equal(32000, _context.VocabCount);
+        }
+    }
+}
diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
new file mode 100644
index 00000000..1c4b9fd7
--- /dev/null
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -0,0 +1,69 @@
+using LLama.Common;
+
+namespace LLama.Unittest
+{
+    /// <summary>
+    /// Tests for <see cref="LLamaEmbedder"/> using the test model.
+    /// </summary>
+    public class LLamaEmbedderTests
+        : IDisposable
+    {
+        private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin"));
+
+        /// <summary>
+        /// Releases the embedder.
+        /// </summary>
+        public void Dispose()
+        {
+            _embedder.Dispose();
+        }
+
+        /// <summary>
+        /// Dot product of two vectors of equal length.
+        /// </summary>
+        private static float Dot(float[] a, float[] b)
+        {
+            Assert.Equal(a.Length, b.Length);
+            return a.Zip(b, (x, y) => x * y).Sum();
+        }
+
+        /// <summary>
+        /// Cosine distance (1 - cosine similarity) between two vectors; smaller means more similar.
+        /// </summary>
+        private static float CosineDistance(float[] a, float[] b)
+        {
+            // Divide by the magnitudes so the comparison does not depend on vector length.
+            var similarity = Dot(a, b) / (float)Math.Sqrt(Dot(a, a) * Dot(b, b));
+            return 1 - similarity;
+        }
+
+        /// <summary>
+        /// Embedding a single word returns a non-empty vector of <see cref="LLamaEmbedder.EmbeddingSize"/> elements.
+        /// </summary>
+        [Fact]
+        public void EmbedHello()
+        {
+            var hello = _embedder.GetEmbeddings("Hello");
+
+            Assert.NotNull(hello);
+            Assert.NotEmpty(hello);
+            Assert.Equal(_embedder.EmbeddingSize, hello.Length);
+        }
+
+        /// <summary>
+        /// "cat" should be closer to "kitten" than to "spoon".
+        /// </summary>
+        [Fact]
+        public void EmbedCompare()
+        {
+            var cat = _embedder.GetEmbeddings("cat");
+            var kitten = _embedder.GetEmbeddings("kitten");
+            var spoon = _embedder.GetEmbeddings("spoon");
+
+            var close = CosineDistance(cat, kitten);
+            var far = CosineDistance(cat, spoon);
+
+            Assert.True(close < far);
+        }
+    }
+}
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index a74f11ee..6b82c4d8 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -1,9 +1,7 @@
 using LLama.Native;
 using System;
-using System.Collections.Generic;
 using System.Text;
 using LLama.Exceptions;
-using System.Linq;
 using LLama.Abstractions;
 
 namespace LLama
@@ -11,9 +9,15 @@ namespace LLama
     /// <summary>
     /// The embedder for LLama, which supports getting embeddings from text.
     /// </summary>
-    public class LLamaEmbedder : IDisposable
+    public class LLamaEmbedder
+        : IDisposable
     {
-        SafeLLamaContextHandle _ctx;
+        private readonly SafeLLamaContextHandle _ctx;
+
+        /// <summary>
+        /// Dimension of embedding vectors
+        /// </summary>
+        public int EmbeddingSize => _ctx.EmbeddingSize;
 
         /// <summary>
         /// Warning: must ensure the original model has params.embedding = true;