diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs new file mode 100644 index 00000000..03487353 --- /dev/null +++ b/LLama.Unittest/LLamaEmbedderTests.cs @@ -0,0 +1,70 @@ +using LLama.Common; + +namespace LLama.Unittest; + +public class LLamaEmbedderTests + : IDisposable +{ + private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")); + + public void Dispose() + { + _embedder.Dispose(); + } + + private static float Magnitude(float[] a) + { + return MathF.Sqrt(a.Zip(a, (x, y) => x * y).Sum()); + } + + private static void Normalize(float[] a) + { + var mag = Magnitude(a); + for (var i = 0; i < a.Length; i++) + a[i] /= mag; + } + + private static float Dot(float[] a, float[] b) + { + Assert.Equal(a.Length, b.Length); + return a.Zip(b, (x, y) => x * y).Sum(); + } + + private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f) + { + for (int i = 0; i < expected.Length; i++) + Assert.Equal(expected[i], actual[i], epsilon); + } + + [Fact] + public void EmbedBasic() + { + var cat = _embedder.GetEmbeddings("cat"); + + Assert.NotNull(cat); + Assert.NotEmpty(cat); + + // Expected value generate with llama.cpp embedding.exe + var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f }; + AssertApproxStartsWith(expected, cat); + } + + [Fact] + public void EmbedCompare() + { + var cat = _embedder.GetEmbeddings("cat"); + var kitten = _embedder.GetEmbeddings("kitten"); + var spoon = _embedder.GetEmbeddings("spoon"); + + Normalize(cat); + Normalize(kitten); + Normalize(spoon); + + var close = Dot(cat, kitten); + var far = Dot(cat, spoon); + + // This comparison seems backwards, but remember that with a + // dot product 1.0 means **identical** and 0.0 means **completely opposite**! + Assert.True(close > far); + } +} \ No newline at end of file diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 5acf756b..57c305b2 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -1,6 +1,5 @@ using LLama.Native; using System; -using System.Text; using LLama.Exceptions; using LLama.Abstractions; @@ -12,22 +11,13 @@ namespace LLama public class LLamaEmbedder : IDisposable { - private readonly SafeLLamaContextHandle _ctx; + private readonly LLamaContext _ctx; /// /// Dimension of embedding vectors /// public int EmbeddingSize => _ctx.EmbeddingSize; - /// - /// Warning: must ensure the original model has params.embedding = true; - /// - /// - internal LLamaEmbedder(SafeLLamaContextHandle ctx) - { - _ctx = ctx; - } - /// /// /// @@ -35,52 +25,66 @@ namespace LLama public LLamaEmbedder(IModelParams @params) { @params.EmbeddingMode = true; - _ctx = Utils.InitLLamaContextFromModelParams(@params); + using var weights = LLamaWeights.LoadFromFile(@params); + _ctx = weights.CreateContext(@params); } /// /// Get the embeddings of the text. /// /// - /// Threads used for inference. + /// unused /// Add bos to the text. - /// + /// unused /// /// - public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") + [Obsolete("'threads' and 'encoding' parameters are no longer used")] + // ReSharper disable once MethodOverloadWithOptionalParameter + public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") { - if (threads == -1) - { - threads = Math.Max(Environment.ProcessorCount / 2, 1); - } + return GetEmbeddings(text, addBos); + } + + /// + /// Get the embeddings of the text. + /// + /// + /// + /// + public float[] GetEmbeddings(string text) + { + return GetEmbeddings(text, true); + } + /// + /// Get the embeddings of the text. + /// + /// + /// Add bos to the text. + /// + /// + public float[] GetEmbeddings(string text, bool addBos) + { if (addBos) { text = text.Insert(0, " "); } - var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding)); + var embed_inp_array = _ctx.Tokenize(text, addBos); // TODO(Rinne): deal with log of prompt if (embed_inp_array.Length > 0) - { - if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, 0, threads) != 0) - { - throw new RuntimeError("Failed to eval."); - } - } + _ctx.Eval(embed_inp_array, 0); - int n_embed = NativeApi.llama_n_embd(_ctx); - var embeddings = NativeApi.llama_get_embeddings(_ctx); - if (embeddings == null) + unsafe { - return Array.Empty(); + var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle); + if (embeddings == null) + return Array.Empty(); + + return new Span(embeddings, EmbeddingSize).ToArray(); } - var span = new Span(embeddings, n_embed); - float[] res = new float[n_embed]; - span.CopyTo(res.AsSpan()); - return res; } ///