- Modified `GetEmbeddings` method to be async (tags/v0.10.0)
| @@ -1,14 +1,7 @@ | |||
| using LLama; | |||
| using LLama.Abstractions; | |||
| using LLama.Common; | |||
| using Microsoft.KernelMemory; | |||
| using Microsoft.KernelMemory.AI; | |||
| using Microsoft.SemanticKernel.AI.Embeddings; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Text; | |||
| using System.Threading.Tasks; | |||
| namespace LLamaSharp.KernelMemory | |||
| { | |||
| @@ -80,24 +73,24 @@ namespace LLamaSharp.KernelMemory | |||
| } | |||
| /// <inheritdoc/> | |||
| public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default) | |||
| public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default) | |||
| { | |||
| IList<ReadOnlyMemory<float>> results = new List<ReadOnlyMemory<float>>(); | |||
| foreach (var d in data) | |||
| { | |||
| var embeddings = _embedder.GetEmbeddings(d); | |||
| var embeddings = await _embedder.GetEmbeddings(d, cancellationToken); | |||
| results.Add(new ReadOnlyMemory<float>(embeddings)); | |||
| } | |||
| return Task.FromResult(results); | |||
| return results; | |||
| } | |||
| /// <inheritdoc/> | |||
| public Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default) | |||
| public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default) | |||
| { | |||
| var embeddings = _embedder.GetEmbeddings(text); | |||
| return Task.FromResult(new Embedding(embeddings)); | |||
| var embeddings = await _embedder.GetEmbeddings(text, cancellationToken); | |||
| return new Embedding(embeddings); | |||
| } | |||
| /// <inheritdoc/> | |||
| @@ -6,7 +6,7 @@ namespace LLamaSharp.SemanticKernel.TextEmbedding; | |||
| public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService | |||
| { | |||
| private LLamaEmbedder _embedder; | |||
| private readonly LLamaEmbedder _embedder; | |||
| private readonly Dictionary<string, object?> _attributes = new(); | |||
| @@ -20,7 +20,11 @@ public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationServ | |||
| /// <inheritdoc/> | |||
| public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default) | |||
| { | |||
| var embeddings = data.Select(text => new ReadOnlyMemory<float>(_embedder.GetEmbeddings(text))).ToList(); | |||
| return await Task.FromResult(embeddings); | |||
| var result = new List<ReadOnlyMemory<float>>(); | |||
| foreach (var item in data) | |||
| result.Add(await _embedder.GetEmbeddings(item, cancellationToken)); | |||
| return result; | |||
| } | |||
| } | |||
| @@ -1,14 +1,17 @@ | |||
| using LLama.Common; | |||
| using Xunit.Abstractions; | |||
| namespace LLama.Unittest; | |||
| public sealed class LLamaEmbedderTests | |||
| : IDisposable | |||
| { | |||
| private readonly ITestOutputHelper _testOutputHelper; | |||
| private readonly LLamaEmbedder _embedder; | |||
| public LLamaEmbedderTests() | |||
| public LLamaEmbedderTests(ITestOutputHelper testOutputHelper) | |||
| { | |||
| _testOutputHelper = testOutputHelper; | |||
| var @params = new ModelParams(Constants.ModelPath) | |||
| { | |||
| EmbeddingMode = true, | |||
| @@ -41,21 +44,23 @@ public sealed class LLamaEmbedderTests | |||
| } | |||
| [Fact] | |||
| public void EmbedCompare() | |||
| public async Task EmbedCompare() | |||
| { | |||
| var cat = _embedder.GetEmbeddings("cat"); | |||
| var kitten = _embedder.GetEmbeddings("kitten"); | |||
| var spoon = _embedder.GetEmbeddings("spoon"); | |||
| var cat = await _embedder.GetEmbeddings("cat"); | |||
| var kitten = await _embedder.GetEmbeddings("kitten"); | |||
| var spoon = await _embedder.GetEmbeddings("spoon"); | |||
| Normalize(cat); | |||
| Normalize(kitten); | |||
| Normalize(spoon); | |||
| var close = Dot(cat, kitten); | |||
| var far = Dot(cat, spoon); | |||
| var close = 1 - Dot(cat, kitten); | |||
| var far = 1 - Dot(cat, spoon); | |||
| // This comparison seems backwards, but remember that with a | |||
| // dot product 1.0 means **identical** and 0.0 means **completely opposite**! | |||
| Assert.True(close > far); | |||
| Assert.True(close < far); | |||
| _testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]"); | |||
| _testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]"); | |||
| _testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]"); | |||
| } | |||
| } | |||
| @@ -3,6 +3,8 @@ using System; | |||
| using LLama.Exceptions; | |||
| using LLama.Abstractions; | |||
| using Microsoft.Extensions.Logging; | |||
| using System.Threading; | |||
| using System.Threading.Tasks; | |||
| namespace LLama | |||
| { | |||
| @@ -40,27 +42,12 @@ namespace LLama | |||
| /// Get the embeddings of the text. | |||
| /// </summary> | |||
| /// <param name="text"></param> | |||
| /// <param name="threads">unused</param> | |||
| /// <param name="addBos">Add bos to the text.</param> | |||
| /// <param name="encoding">unused</param> | |||
| /// <param name="cancellationToken"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| [Obsolete("'threads' and 'encoding' parameters are no longer used")] | |||
| // ReSharper disable once MethodOverloadWithOptionalParameter | |||
| public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8") | |||
| public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default) | |||
| { | |||
| return GetEmbeddings(text, addBos); | |||
| } | |||
| /// <summary> | |||
| /// Get the embeddings of the text. | |||
| /// </summary> | |||
| /// <param name="text"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| public float[] GetEmbeddings(string text) | |||
| { | |||
| return GetEmbeddings(text, true); | |||
| return GetEmbeddings(text, true, cancellationToken); | |||
| } | |||
| /// <summary> | |||
| @@ -68,22 +55,48 @@ namespace LLama | |||
| /// </summary> | |||
| /// <param name="text"></param> | |||
| /// <param name="addBos">Add bos to the text.</param> | |||
| /// <param name="cancellationToken"></param> | |||
| /// <returns></returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| public float[] GetEmbeddings(string text, bool addBos) | |||
| public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default) | |||
| { | |||
| var embed_inp_array = Context.Tokenize(text, addBos); | |||
| var tokens = Context.Tokenize(text, addBos); | |||
| if (tokens.Length > Context.ContextSize) | |||
| throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text)); | |||
| // Evaluate prompt in batch-size chunks | |||
| var n_past = 0; | |||
| var batch = new LLamaBatch(); | |||
| var batchSize = (int)Context.Params.BatchSize; | |||
| for (var i = 0; i < tokens.Length; i += batchSize) | |||
| { | |||
| var n_eval = tokens.Length - i; | |||
| if (n_eval > batchSize) | |||
| n_eval = batchSize; | |||
| batch.Clear(); | |||
| for (var j = 0; j < n_eval; j++) | |||
| batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, false); | |||
| var returnCode = await Context.DecodeAsync(batch, cancellationToken); | |||
| if (returnCode != 0) | |||
| throw new LLamaDecodeError(returnCode); | |||
| } | |||
| // TODO(Rinne): deal with log of prompt | |||
| var embeddings = GetEmbeddingsArray(); | |||
| if (embed_inp_array.Length > 0) | |||
| Context.Eval(embed_inp_array.AsSpan(), 0); | |||
| // Remove everything we just evaluated from the context cache | |||
| Context.NativeHandle.KvCacheClear(); | |||
| var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); | |||
| if (embeddings == null) | |||
| return Array.Empty<float>(); | |||
| return embeddings; | |||
| return embeddings.ToArray(); | |||
| float[] GetEmbeddingsArray() | |||
| { | |||
| var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); | |||
| if (embeddings == null) | |||
| return Array.Empty<float>(); | |||
| return embeddings.ToArray(); | |||
| } | |||
| } | |||
| /// <summary> | |||