Browse Source

- Swapped embeddings generator to use `llama_decode`

- Modified `GetEmbeddings` method to be async
tags/v0.10.0
Martin Evans 2 years ago
parent
commit
c9c8cd0d62
4 changed files with 68 additions and 53 deletions
  1. +6
    -13
      LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs
  2. +7
    -3
      LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs
  3. +15
    -10
      LLama.Unittest/LLamaEmbedderTests.cs
  4. +40
    -27
      LLama/LLamaEmbedder.cs

+ 6
- 13
LLama.KernelMemory/LLamaSharpTextEmbeddingGenerator.cs View File

@@ -1,14 +1,7 @@
using LLama;
using LLama.Abstractions;
using LLama.Common;
using Microsoft.KernelMemory;
using Microsoft.KernelMemory.AI;
using Microsoft.SemanticKernel.AI.Embeddings;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace LLamaSharp.KernelMemory
{
@@ -80,24 +73,24 @@ namespace LLamaSharp.KernelMemory
}

/// <inheritdoc/>
public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
{
IList<ReadOnlyMemory<float>> results = new List<ReadOnlyMemory<float>>();

foreach (var d in data)
{
var embeddings = _embedder.GetEmbeddings(d);
var embeddings = await _embedder.GetEmbeddings(d, cancellationToken);
results.Add(new ReadOnlyMemory<float>(embeddings));
}

return Task.FromResult(results);
return results;
}

/// <inheritdoc/>
public Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
public async Task<Embedding> GenerateEmbeddingAsync(string text, CancellationToken cancellationToken = default)
{
var embeddings = _embedder.GetEmbeddings(text);
return Task.FromResult(new Embedding(embeddings));
var embeddings = await _embedder.GetEmbeddings(text, cancellationToken);
return new Embedding(embeddings);
}

/// <inheritdoc/>


+ 7
- 3
LLama.SemanticKernel/TextEmbedding/LLamaSharpEmbeddingGeneration.cs View File

@@ -6,7 +6,7 @@ namespace LLamaSharp.SemanticKernel.TextEmbedding;

public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationService
{
private LLamaEmbedder _embedder;
private readonly LLamaEmbedder _embedder;

private readonly Dictionary<string, object?> _attributes = new();

@@ -20,7 +20,11 @@ public sealed class LLamaSharpEmbeddingGeneration : ITextEmbeddingGenerationServ
/// <inheritdoc/>
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
var embeddings = data.Select(text => new ReadOnlyMemory<float>(_embedder.GetEmbeddings(text))).ToList();
return await Task.FromResult(embeddings);
var result = new List<ReadOnlyMemory<float>>();

foreach (var item in data)
result.Add(await _embedder.GetEmbeddings(item, cancellationToken));

return result;
}
}

+ 15
- 10
LLama.Unittest/LLamaEmbedderTests.cs View File

@@ -1,14 +1,17 @@
using LLama.Common;
using Xunit.Abstractions;

namespace LLama.Unittest;

public sealed class LLamaEmbedderTests
: IDisposable
{
private readonly ITestOutputHelper _testOutputHelper;
private readonly LLamaEmbedder _embedder;

public LLamaEmbedderTests()
public LLamaEmbedderTests(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
var @params = new ModelParams(Constants.ModelPath)
{
EmbeddingMode = true,
@@ -41,21 +44,23 @@ public sealed class LLamaEmbedderTests
}

[Fact]
public void EmbedCompare()
public async Task EmbedCompare()
{
var cat = _embedder.GetEmbeddings("cat");
var kitten = _embedder.GetEmbeddings("kitten");
var spoon = _embedder.GetEmbeddings("spoon");
var cat = await _embedder.GetEmbeddings("cat");
var kitten = await _embedder.GetEmbeddings("kitten");
var spoon = await _embedder.GetEmbeddings("spoon");

Normalize(cat);
Normalize(kitten);
Normalize(spoon);

var close = Dot(cat, kitten);
var far = Dot(cat, spoon);
var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);

// This comparison seems backwards, but remember that with a dot product of
// normalized vectors 1.0 means **identical**, 0.0 means **orthogonal**, and -1.0 means **completely opposite**!
Assert.True(close > far);
Assert.True(close < far);

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");
}
}

+ 40
- 27
LLama/LLamaEmbedder.cs View File

@@ -3,6 +3,8 @@ using System;
using LLama.Exceptions;
using LLama.Abstractions;
using Microsoft.Extensions.Logging;
using System.Threading;
using System.Threading.Tasks;

namespace LLama
{
@@ -40,27 +42,12 @@ namespace LLama
/// Get the embeddings of the text.
/// </summary>
/// <param name="text"></param>
/// <param name="threads">unused</param>
/// <param name="addBos">Add bos to the text.</param>
/// <param name="encoding">unused</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("'threads' and 'encoding' parameters are no longer used")]
// ReSharper disable once MethodOverloadWithOptionalParameter
public float[] GetEmbeddings(string text, int threads = -1, bool addBos = true, string encoding = "UTF-8")
public Task<float[]> GetEmbeddings(string text, CancellationToken cancellationToken = default)
{
return GetEmbeddings(text, addBos);
}

/// <summary>
/// Get the embeddings of the text.
/// </summary>
/// <param name="text"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text)
{
return GetEmbeddings(text, true);
return GetEmbeddings(text, true, cancellationToken);
}

/// <summary>
@@ -68,22 +55,48 @@ namespace LLama
/// </summary>
/// <param name="text"></param>
/// <param name="addBos">Add bos to the text.</param>
/// <param name="cancellationToken"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public float[] GetEmbeddings(string text, bool addBos)
public async Task<float[]> GetEmbeddings(string text, bool addBos, CancellationToken cancellationToken = default)
{
var embed_inp_array = Context.Tokenize(text, addBos);
var tokens = Context.Tokenize(text, addBos);
if (tokens.Length > Context.ContextSize)
throw new ArgumentException($"Embedding prompt is longer than the context window ({tokens.Length} > {Context.ContextSize})", nameof(text));

// Evaluate prompt in batch-size chunks
var n_past = 0;
var batch = new LLamaBatch();
var batchSize = (int)Context.Params.BatchSize;
for (var i = 0; i < tokens.Length; i += batchSize)
{
var n_eval = tokens.Length - i;
if (n_eval > batchSize)
n_eval = batchSize;

batch.Clear();
for (var j = 0; j < n_eval; j++)
batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, false);

var returnCode = await Context.DecodeAsync(batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}

// TODO(Rinne): deal with log of prompt
var embeddings = GetEmbeddingsArray();

if (embed_inp_array.Length > 0)
Context.Eval(embed_inp_array.AsSpan(), 0);
// Remove everything we just evaluated from the context cache
Context.NativeHandle.KvCacheClear();

var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();
return embeddings;

return embeddings.ToArray();
float[] GetEmbeddingsArray()
{
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
return Array.Empty<float>();
return embeddings.ToArray();
}
}

/// <summary>


Loading…
Cancel
Save