diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
new file mode 100644
index 00000000..03487353
--- /dev/null
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -0,0 +1,70 @@
+using LLama.Common;
+
+namespace LLama.Unittest;
+
+public class LLamaEmbedderTests
+ : IDisposable
+{
+ private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin"));
+
+ public void Dispose()
+ {
+ _embedder.Dispose();
+ }
+
+ private static float Magnitude(float[] a)
+ {
+ return MathF.Sqrt(a.Zip(a, (x, y) => x * y).Sum());
+ }
+
+ private static void Normalize(float[] a)
+ {
+ var mag = Magnitude(a);
+ for (var i = 0; i < a.Length; i++)
+ a[i] /= mag;
+ }
+
+ private static float Dot(float[] a, float[] b)
+ {
+ Assert.Equal(a.Length, b.Length);
+ return a.Zip(b, (x, y) => x * y).Sum();
+ }
+
+ private static void AssertApproxStartsWith(float[] expected, float[] actual, float epsilon = 0.08f)
+ {
+ for (int i = 0; i < expected.Length; i++)
+ Assert.Equal(expected[i], actual[i], epsilon);
+ }
+
+ [Fact]
+ public void EmbedBasic()
+ {
+ var cat = _embedder.GetEmbeddings("cat");
+
+ Assert.NotNull(cat);
+ Assert.NotEmpty(cat);
+
+ // Expected value generate with llama.cpp embedding.exe
+ var expected = new float[] { -0.127304f, -0.678057f, -0.085244f, -0.956915f, -0.638633f };
+ AssertApproxStartsWith(expected, cat);
+ }
+
+ [Fact]
+ public void EmbedCompare()
+ {
+ var cat = _embedder.GetEmbeddings("cat");
+ var kitten = _embedder.GetEmbeddings("kitten");
+ var spoon = _embedder.GetEmbeddings("spoon");
+
+ Normalize(cat);
+ Normalize(kitten);
+ Normalize(spoon);
+
+ var close = Dot(cat, kitten);
+ var far = Dot(cat, spoon);
+
+ // This comparison seems backwards, but remember that with a
+ // dot product 1.0 means **identical** and 0.0 means **completely opposite**!
+ Assert.True(close > far);
+ }
+}
\ No newline at end of file
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 5acf756b..57c305b2 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -1,6 +1,5 @@
using LLama.Native;
using System;
-using System.Text;
using LLama.Exceptions;
using LLama.Abstractions;
@@ -12,22 +11,13 @@ namespace LLama
public class LLamaEmbedder
: IDisposable
{
- private readonly SafeLLamaContextHandle _ctx;
+ private readonly LLamaContext _ctx;
///
/// Dimension of embedding vectors
///
public int EmbeddingSize => _ctx.EmbeddingSize;
- ///
- /// Warning: must ensure the original model has params.embedding = true;
- ///
- ///
- internal LLamaEmbedder(SafeLLamaContextHandle ctx)
- {
- _ctx = ctx;
- }
-
///
///
///
@@ -35,52 +25,66 @@ namespace LLama
public LLamaEmbedder(IModelParams @params)
{
@params.EmbeddingMode = true;
- _ctx = Utils.InitLLamaContextFromModelParams(@params);
+ using var weights = LLamaWeights.LoadFromFile(@params);
+ _ctx = weights.CreateContext(@params);
}
///
/// Get the embeddings of the text.
///
///
- /// Threads used for inference.
+ /// unused
/// Add bos to the text.
- ///
+ /// unused
///
///
- public unsafe float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
+ [Obsolete("'threads' and 'encoding' parameters are no longer used")]
+ // ReSharper disable once MethodOverloadWithOptionalParameter
+ public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
{
- if (threads == -1)
- {
- threads = Math.Max(Environment.ProcessorCount / 2, 1);
- }
+ return GetEmbeddings(text, addBos);
+ }
+
+ ///
+ /// Get the embeddings of the text.
+ ///
+ ///
+ ///
+ ///
+ public float[] GetEmbeddings(string text)
+ {
+ return GetEmbeddings(text, true);
+ }
+ ///
+ /// Get the embeddings of the text.
+ ///
+ ///
+ /// Add bos to the text.
+ ///
+ ///
+ public float[] GetEmbeddings(string text, bool addBos)
+ {
if (addBos)
{
text = text.Insert(0, " ");
}
- var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding));
+ var embed_inp_array = _ctx.Tokenize(text, addBos);
// TODO(Rinne): deal with log of prompt
if (embed_inp_array.Length > 0)
- {
- if (NativeApi.llama_eval(_ctx, embed_inp_array, embed_inp_array.Length, 0, threads) != 0)
- {
- throw new RuntimeError("Failed to eval.");
- }
- }
+ _ctx.Eval(embed_inp_array, 0);
- int n_embed = NativeApi.llama_n_embd(_ctx);
- var embeddings = NativeApi.llama_get_embeddings(_ctx);
- if (embeddings == null)
+ unsafe
{
- return Array.Empty();
+ var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle);
+ if (embeddings == null)
+ return Array.Empty();
+
+ return new Span(embeddings, EmbeddingSize).ToArray();
}
- var span = new Span(embeddings, n_embed);
- float[] res = new float[n_embed];
- span.CopyTo(res.AsSpan());
- return res;
}
///