diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index 5a07c86c..2a88f25d 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -1,4 +1,3 @@
-using LLama;
using LLama.Common;
namespace LLama.Unittest
diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
new file mode 100644
index 00000000..a34f58cb
--- /dev/null
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -0,0 +1,36 @@
+using System.Text;
+using LLama.Common;
+
+namespace LLama.Unittest
+{
+ public class LLamaContextTests
+ : IDisposable
+ {
+ private readonly LLamaWeights _weights;
+ private readonly LLamaContext _context;
+
+ public LLamaContextTests()
+ {
+ var @params = new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin")
+ {
+ ContextSize = 768,
+ };
+ _weights = LLamaWeights.LoadFromFile(@params);
+ _context = _weights.CreateContext(@params, Encoding.UTF8);
+ }
+
+ public void Dispose()
+ {
+ _weights.Dispose();
+ _context.Dispose();
+ }
+
+ [Fact]
+ public void CheckProperties()
+ {
+ Assert.Equal(768, _context.ContextSize);
+ Assert.Equal(4096, _context.EmbeddingSize);
+ Assert.Equal(32000, _context.VocabCount);
+ }
+ }
+}
diff --git a/LLama.Unittest/LLamaEmbedderTests.cs b/LLama.Unittest/LLamaEmbedderTests.cs
new file mode 100644
index 00000000..1c4b9fd7
--- /dev/null
+++ b/LLama.Unittest/LLamaEmbedderTests.cs
@@ -0,0 +1,44 @@
+using LLama.Common;
+
+namespace LLama.Unittest
+{
+ public class LLamaEmbedderTests
+ : IDisposable
+ {
+ private readonly LLamaEmbedder _embedder = new(new ModelParams("Models/llama-2-7b-chat.ggmlv3.q3_K_S.bin"));
+
+ public void Dispose()
+ {
+ _embedder.Dispose();
+ }
+
+ private static float Dot(float[] a, float[] b)
+ {
+ Assert.Equal(a.Length, b.Length);
+ return a.Zip(b, (x, y) => x + y).Sum();
+ }
+
+ [Fact]
+ public void EmbedHello()
+ {
+ var hello = _embedder.GetEmbeddings("Hello");
+
+ Assert.NotNull(hello);
+ Assert.NotEmpty(hello);
+ Assert.Equal(_embedder.EmbeddingSize, hello.Length);
+ }
+
+ [Fact]
+ public void EmbedCompare()
+ {
+ var cat = _embedder.GetEmbeddings("cat");
+ var kitten = _embedder.GetEmbeddings("kitten");
+ var spoon = _embedder.GetEmbeddings("spoon");
+
+ var close = Dot(cat, kitten);
+ var far = Dot(cat, spoon);
+
+ Assert.True(close < far);
+ }
+ }
+}
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index a74f11ee..6b82c4d8 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -1,9 +1,7 @@
using LLama.Native;
using System;
-using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
-using System.Linq;
using LLama.Abstractions;
namespace LLama
@@ -11,9 +9,15 @@ namespace LLama
///
/// The embedder for LLama, which supports getting embeddings from text.
///
- public class LLamaEmbedder : IDisposable
+ public class LLamaEmbedder
+ : IDisposable
{
- SafeLLamaContextHandle _ctx;
+ private readonly SafeLLamaContextHandle _ctx;
+
+ ///
+ /// Dimension of embedding vectors
+ ///
+ public int EmbeddingSize => _ctx.EmbeddingSize;
///
/// Warning: must ensure the original model has params.embedding = true;