
refactor: drop LLamaModelV1.

tags/v0.2.3
Yaohui Liu 3 years ago
commit 56c56b9c51
7 changed files with 1 addition and 1031 deletions
1. LLama.Examples/ChatWithLLamaModelV1.cs (+0, -39)
2. LLama.Examples/Program.cs (+0, -1)
3. LLama.Unittest/BasicTest.cs (+1, -5)
4. LLama/LLamaCache.cs (+0, -99)
5. LLama/LLamaModelV1.cs (+0, -836)
6. LLama/LLamaState.cs (+0, -10)
7. LLama/LLamaTypes.cs (+0, -41)

LLama.Examples/ChatWithLLamaModelV1.cs (+0, -39)

@@ -1,39 +0,0 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using LLama.Types;

namespace LLama.Examples
{
    public class ChatWithLLamaModelV1
    {
        LLamaModelV1 _model;

        public ChatWithLLamaModelV1(string modelPath)
        {
            _model = new(modelPath, logits_all: false, verbose: false, n_ctx: 512);
        }

        public void Run()
        {
            List<ChatCompletionMessage> chats = new List<ChatCompletionMessage>();
            chats.Add(new ChatCompletionMessage(ChatRole.Human, "Hi, Alice, I'm Rinne."));
            chats.Add(new ChatCompletionMessage(ChatRole.Assistant, "Hi, Rinne, I'm Alice, an assistant that answers any question. What can I do for you?"));
            while (true)
            {
                Console.Write("\nYou: ");
                Console.ForegroundColor = ConsoleColor.Green;
                var question = Console.ReadLine();
                Console.ForegroundColor = ConsoleColor.White;
                chats.Add(new ChatCompletionMessage(ChatRole.Human, question));
                var outputs = _model.CreateChatCompletion(chats, max_tokens: 256);
                Console.Write("LLama AI: ");
                foreach (var output in outputs)
                {
                    Console.Write(output.Choices[0].Delta.Content);
                }
            }
        }
    }
}

LLama.Examples/Program.cs (+0, -1)

@@ -1,6 +1,5 @@
using LLama;
using LLama.Examples;
using LLama.Types;

Console.WriteLine("================LLamaSharp Examples==================\n");



LLama.Unittest/BasicTest.cs (+1, -5)

@@ -5,11 +5,7 @@ namespace LLama.Unittest
        [Fact]
        public void SimpleQA()
        {
            string modelPath = @"D:\development\llama\weights\LLaMA\7B\ggml-model-f32.bin";
            LLamaModelV1 model = new(modelPath, logits_all: false);
            var output = model.Call("Q: Why God makes many people believe him? A: ", max_tokens: 64, stop: new[] { "Q:", "\n" },
                echo: true);
            Console.WriteLine(output);
        }
    }
}

LLama/LLamaCache.cs (+0, -99)

@@ -1,99 +0,0 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// Cache for a llama.cpp model.
    /// </summary>
    public class LLamaCache
    {
        // Maps a token sequence to its node in the LRU list; the list keeps
        // the least-recently-used entry at the front.
        private Dictionary<llama_token[], LinkedListNode<KeyValuePair<llama_token[], LLamaState>>> _cacheState;
        private LinkedList<KeyValuePair<llama_token[], LLamaState>> _cacheList;
        private long _capacity;

        public int CacheSize
        {
            get
            {
                return _cacheState.Values.Select(s => s.Value.Value.Size).Sum();
            }
        }

        /// <summary>
        ///
        /// </summary>
        /// <param name="capacity">The max capacity (bytes).</param>
        public LLamaCache(long capacity = 2L << 30)
        {
            _cacheState = new();
            _cacheList = new();
            _capacity = capacity;
        }

        public LLamaState this[llama_token[] key]
        {
            get
            {
                var prefixKey = FindLongestPrefixKey(key);
                if (prefixKey is null)
                {
                    throw new KeyNotFoundException();
                }
                var value = _cacheState[prefixKey];
                MoveNodeToEnd(prefixKey);
                return value.Value.Value;
            }
            set
            {
                var node = _cacheList.AddLast(new KeyValuePair<llama_token[], LLamaState>(key, value));
                _cacheState[key] = node;
                // Evict least-recently-used entries until the total state size fits the capacity.
                while (CacheSize > _capacity && _cacheList.Count > 0)
                {
                    var toPop = _cacheList.First;
                    _cacheState.Remove(toPop.Value.Key);
                    _cacheList.RemoveFirst();
                }
            }
        }

        public bool Contains(llama_token[] key)
        {
            return FindLongestPrefixKey(key) is not null;
        }

        private llama_token[]? FindLongestPrefixKey(llama_token[] key)
        {
            int maxLen = 0;
            llama_token[]? maxKey = null;
            var keys = _cacheState.Keys.Select(k => (k, LLamaModelV1.LongestTokenPrefix(k, key)));
            foreach (var (k, prefixLen) in keys)
            {
                if (prefixLen > maxLen)
                {
                    maxLen = prefixLen;
                    maxKey = k;
                }
            }
            return maxKey;
        }

        private void MoveNodeToEnd(llama_token[] key)
        {
            if (!_cacheState.TryGetValue(key, out var node))
            {
                return;
            }

            _cacheState.Remove(key);
            _cacheList.Remove(node);

            var newNode = _cacheList.AddLast(new KeyValuePair<llama_token[], LLamaState>(key, node.Value.Value));
            _cacheState.Add(key, newNode);
        }
    }
}
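For context, here is a minimal sketch of how this prefix-keyed LRU cache plugged into the model. This is hypothetical usage, not part of the commit: the model path, capacity, and prompts are made up, and only `SetCache`, `Call`, and the `LLamaCache` constructor are taken from the code in this diff.

using System;
using LLama;

// Hypothetical usage sketch: attach a cache so a later prompt that shares a
// token prefix can restore saved state instead of re-evaluating the prefix.
using var model = new LLamaModelV1("models/ggml-7b.bin", verbose: false); // assumed path
model.SetCache(new LLamaCache(capacity: 1 << 28)); // ~256 MB of serialized state

// The first call evaluates the full prompt; afterwards the completion stores
// SaveState() under the prompt+completion tokens.
foreach (var _ in model.Call("Once upon a time", max_tokens: 16)) { }

// A prompt extending the same prefix now hits the cache via FindLongestPrefixKey.
foreach (var chunk in model.Call("Once upon a time, in a distant land", max_tokens: 16))
    Console.Write(chunk.Choices[0].Text);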

LLama/LLamaModelV1.cs (+0, -836)

@@ -1,836 +0,0 @@
using LLama.Exceptions;
using LLama.Native;
using System;
using System.Collections.Generic;
using System.Configuration;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using LLama.Types;
using System.Runtime.InteropServices;
using System.Text.RegularExpressions;
using System.Collections;

namespace LLama
{
    using llama_token = Int32;

    /// <summary>
    /// High-level wrapper of a llama.cpp model for inference. Note that `LLamaModel` is recommended instead,
    /// and this class may be removed in the future. However, if all you want is embeddings, `LLamaModelV1`
    /// is fine for now.
    /// </summary>
    [Obsolete]
    public class LLamaModelV1 : IDisposable
    {
        private string _model_path;
        LLamaContextParams _params;
        private int _n_threads;
        private int _n_batch;
        private int _last_n_tokens_size;
        private string? _lora_base;
        private string? _lora_path;
        private bool _verbose;

        private Queue<llama_token> _eval_tokens;
        private Queue<float[]> _eval_logits;
        private LLamaCache? _cache;
        private SafeLLamaContextHandle _ctx;

        // (byte count, leading-byte pattern) pairs for UTF-8 multi-byte sequences:
        // 0b110xxxxx (192) opens a 2-byte sequence, 0b1110xxxx (224) a 3-byte one,
        // 0b11110xxx (240) a 4-byte one.
        private static readonly (int, int)[] _numAndPatterns = new (int, int)[] { (2, 192), (3, 224), (4, 240) };

        /// <summary>
        /// Load a llama.cpp model from the path.
        /// </summary>
        /// <remarks>Note that the API is still unstable and the parameter order is likely to
        /// change in the future, so it is recommended to use named arguments when
        /// building your app. The cpp-style parameter names are kept here because they make
        /// searching the llama.cpp docs easier.</remarks>
        /// <param name="model_path">Path to the model.</param>
        /// <param name="n_ctx">Maximum context size.</param>
        /// <param name="n_parts">Number of parts to split the model into. If -1, the number of parts is determined automatically.</param>
        /// <param name="seed">Random seed. 0 for random.</param>
        /// <param name="f16_kv">Use half-precision for the key/value cache.</param>
        /// <param name="logits_all">Return logits for all tokens, not just the last token.</param>
        /// <param name="vocab_only">Only load the vocabulary, not the weights.</param>
        /// <param name="use_mmap">Use mmap if possible.</param>
        /// <param name="use_mlock">Force the system to keep the model in RAM.</param>
        /// <param name="embedding">Embedding mode only.</param>
        /// <param name="n_threads">Number of threads to use. If not specified, the number of threads is determined automatically.</param>
        /// <param name="n_batch">Maximum number of prompt tokens to batch together when calling llama_eval.</param>
        /// <param name="last_n_tokens_size">Maximum number of tokens to keep in the last_n_tokens deque.</param>
        /// <param name="lora_base">Optional path to the base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.</param>
        /// <param name="lora_path">Path to a LoRA file to apply to the model.</param>
        /// <param name="verbose">Print verbose output to stderr.</param>
        public LLamaModelV1(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
            bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
            bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
            int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
        {
            _verbose = verbose;
            _model_path = model_path;

            _params = NativeApi.llama_context_default_params();
            _params.n_ctx = n_ctx;
            _params.seed = seed;
            _params.f16_kv = f16_kv;
            _params.logits_all = logits_all;
            _params.vocab_only = vocab_only;
            _params.use_mmap = lora_path is null ? use_mmap : false;
            _params.use_mlock = use_mlock;
            _params.embedding = embedding;

            _last_n_tokens_size = last_n_tokens_size;
            _n_batch = Math.Min(n_ctx, n_batch);

            _eval_tokens = new Queue<llama_token>(capacity: n_ctx);
            _eval_logits = new Queue<float[]>(logits_all ? n_ctx : 1);

            _cache = null;

            _n_threads = n_threads;
            if (_n_threads == -1)
            {
                _n_threads = Math.Max(Environment.ProcessorCount / 2, 1);
            }

            _lora_base = lora_base;
            _lora_path = lora_path;

            if (!File.Exists(model_path) && !Directory.Exists(model_path))
            {
                throw new FileNotFoundException($"Model path does not exist: {model_path}");
            }

            _ctx = new SafeLLamaContextHandle(NativeApi.llama_init_from_file(model_path, _params));

            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);

            if (_lora_path is not null)
            {
                if (NativeApi.llama_apply_lora_from_file(_ctx, lora_path, lora_base, _n_threads) != 0)
                {
                    throw new RuntimeError($"Failed to apply LoRA from lora path: {_lora_path} to base path: {_lora_base}");
                }
            }

            if (_verbose)
            {
                Logger.Default.Info(Utils.PtrToStringUTF8(NativeApi.llama_print_system_info()));
            }
        }

        public LLamaModelV1(LLamaModelV1 other)
        {
            _ctx = other._ctx;
            _model_path = other._model_path;
            _params = other._params;
            _last_n_tokens_size = other._last_n_tokens_size;
            _n_threads = other._n_threads;
            _n_batch = other._n_batch;
            _verbose = other._verbose;
            _lora_base = other._lora_base;
            _lora_path = other._lora_path;
            _eval_logits = new Queue<float[]>(other._eval_logits);
            _eval_tokens = new Queue<llama_token>(other._eval_tokens);
        }

        /// <summary>
        /// Tokenize a string.
        /// </summary>
        /// <param name="text">The utf-8 encoded string to tokenize.</param>
        /// <returns>A list of tokens.</returns>
        /// <exception cref="RuntimeError">If the tokenization failed.</exception>
        public List<llama_token> Tokenize(string text)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            var n_ctx = NativeApi.llama_n_ctx(_ctx);
            var tokens = new llama_token[n_ctx];
            var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
            if (n_tokens < 0)
            {
                throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
            }
            return tokens.Take(n_tokens).ToList();
        }

        /// <summary>
        /// Detokenize a list of tokens.
        /// </summary>
        /// <param name="tokens">The list of tokens to detokenize.</param>
        /// <returns>The detokenized string.</returns>
        public string DeTokenize(IEnumerable<llama_token> tokens)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            StringBuilder output = new();
            foreach (var token in tokens)
            {
                output.Append(Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)));
            }
            return output.ToString();
        }

        public string DeTokenize(llama_token token)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            return Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)) ?? "";
        }

        /// <summary>
        /// Set the cache.
        /// </summary>
        /// <param name="cache">The cache to set.</param>
        public void SetCache(LLamaCache? cache)
        {
            _cache = cache;
        }

        /// <summary>
        /// Reset the model state.
        /// </summary>
        public void Reset()
        {
            _eval_tokens.Clear();
            _eval_logits.Clear();
        }

        /// <summary>
        /// Evaluate a list of tokens.
        /// </summary>
        /// <param name="tokens">The list of tokens to evaluate.</param>
        /// <exception cref="RuntimeError"></exception>
        public unsafe void Eval(List<llama_token> tokens)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            var n_ctx = NativeApi.llama_n_ctx(_ctx);
            // Feed the tokens to llama_eval in batches of at most _n_batch.
            for (int i = 0; i < tokens.Count; i += _n_batch)
            {
                var batch = tokens.Skip(i).Take(_n_batch);
                int n_past = Math.Min(n_ctx - batch.Count(), _eval_tokens.Count);
                int n_tokens = batch.Count();
                int return_code = NativeApi.llama_eval(
                    ctx: _ctx,
                    tokens: batch.ToArray(),
                    n_tokens: n_tokens,
                    n_past: n_past,
                    n_threads: _n_threads
                );
                if (return_code != 0)
                {
                    throw new RuntimeError($"llama_eval returned {return_code}");
                }
                foreach (var b in batch)
                {
                    _eval_tokens.Enqueue(b);
                }
                // Keep one row of logits per evaluated token when logits_all is set,
                // otherwise only the row for the last token.
                int rows = _params.logits_all ? n_tokens : 1;
                int n_vocab = NativeApi.llama_n_vocab(_ctx);
                var cols = n_vocab;
                var logits_view = NativeApi.llama_get_logits(_ctx);
                for (int j = 0; j < rows; j++)
                {
                    float[] logit = new float[cols];
                    for (int k = 0; k < cols; k++)
                    {
                        logit[k] = logits_view[j * cols + k];
                    }
                    _eval_logits.Enqueue(logit);
                }
            }
        }

        private llama_token SampleInternal(llama_token[] last_n_tokens_data, int last_n_tokens_size, int top_k,
            float top_p, float temp, float repeat_penalty, float frequency_penalty, float presence_penalty)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            Debug.Assert(_eval_logits.Count > 0);
            llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
            var logits = _eval_logits.Last();
            LLamaTokenData[] data = new LLamaTokenData[n_vocab];
            for (int i = 0; i < n_vocab; i++)
            {
                data[i] = new LLamaTokenData(i, logits[i], .0f);
            }
            ulong size = (ulong)n_vocab;
            bool sorted = false;
            LLamaTokenDataArray candidates = new(data, size, sorted);
            SamplingApi.llama_sample_repetition_penalty(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
                repeat_penalty);
            //SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
            //    frequency_penalty, presence_penalty);
            if (temp == .0f)
            {
                // Temperature 0 means deterministic, greedy decoding.
                return SamplingApi.llama_sample_token_greedy(_ctx, candidates);
            }
            else
            {
                SamplingApi.llama_sample_top_k(_ctx, candidates, top_k, 1);
                SamplingApi.llama_sample_tail_free(_ctx, candidates, 1.0f, 1);
                SamplingApi.llama_sample_typical(_ctx, candidates, 1.0f, 1);
                SamplingApi.llama_sample_top_p(_ctx, candidates, top_p, 1);
                SamplingApi.llama_sample_temperature(_ctx, candidates, temp);
                return SamplingApi.llama_sample_token(_ctx, candidates);
            }
        }

        /// <summary>
        /// Sample a token from the model.
        /// </summary>
        /// <param name="top_k">The top-k sampling parameter.</param>
        /// <param name="top_p">The top-p sampling parameter.</param>
        /// <param name="temp">The temperature parameter.</param>
        /// <param name="repeat_penalty">The repeat penalty parameter.</param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <returns>The sampled token.</returns>
        public llama_token Sample(int top_k, float top_p, float temp, float repeat_penalty, float frequency_penalty = .0f,
            float presence_penalty = .0f)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            // Left-pad with zeros so the window always contains _last_n_tokens_size entries.
            var last_n_tokens_data = Enumerable.Repeat(0, Math.Max(0, _last_n_tokens_size - _eval_tokens.Count));
            last_n_tokens_data = last_n_tokens_data.Concat(_eval_tokens.ToList()
                .Skip(Math.Max(0, _eval_tokens.Count - _last_n_tokens_size)));
            llama_token[] tokens_data = new llama_token[_last_n_tokens_size];
            int i = 0;
            foreach (var data in last_n_tokens_data)
            {
                if (i < _last_n_tokens_size)
                {
                    tokens_data[i++] = data;
                }
                else
                {
                    break;
                }
            }
            return SampleInternal(tokens_data, _last_n_tokens_size, top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
        }

        /// <summary>
        /// Create a generator of tokens from a prompt.
        /// </summary>
        /// <example>
        /// Examples:
        /// var llama = new LLamaModelV1("models/ggml-7b.bin");
        /// var tokens = llama.Tokenize("Hello, world!");
        /// foreach (var token in llama.Generate(tokens, top_k: 40, top_p: 0.95f, temp: 1.0f, repeat_penalty: 1.1f))
        /// {
        ///     Console.WriteLine(llama.DeTokenize(new[] { token }));
        /// }
        /// </example>
        /// <param name="tokens"></param>
        /// <param name="top_k"></param>
        /// <param name="top_p"></param>
        /// <param name="temp"></param>
        /// <param name="repeat_penalty"></param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="reset"></param>
        /// <returns></returns>
        public IEnumerable<llama_token> Generate(IEnumerable<llama_token> tokens, int top_k, float top_p, float temp,
            float repeat_penalty, float frequency_penalty = .0f, float presence_penalty = .0f, bool reset = true)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            if (reset && _eval_tokens.Count > 0)
            {
                // If the new prompt shares a prefix with the already-evaluated tokens,
                // keep that prefix instead of resetting and only evaluate the remainder.
                int longest_prefix = 0;
                foreach (var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count() - 1), (x, y) => (x, y)))
                {
                    if (a == b)
                    {
                        longest_prefix += 1;
                    }
                    else
                    {
                        break;
                    }
                }
                if (longest_prefix > 0)
                {
                    if (_verbose)
                    {
                        Logger.Default.Info("Llama.generate: prefix-match hit");
                    }
                    reset = false;
                    tokens = tokens.Skip(longest_prefix);
                    for (int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
                    {
                        _eval_tokens.Dequeue();
                        if (_eval_logits.Count > 0)
                        {
                            _eval_logits.Dequeue();
                        }
                    }
                }
            }

            if (reset)
            {
                Reset();
            }

            while (true)
            {
                Eval(tokens.ToList());
                var token = Sample(top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
                yield return token;
                // Feed the sampled token back as the sole input of the next step.
                tokens = new[] { token };
            }
        }

        /// <summary>
        /// Embed a string.
        /// </summary>
        /// <param name="input">The utf-8 encoded string to embed.</param>
        /// <returns>An embedding object.</returns>
        /// <exception cref="RuntimeError"></exception>
        public unsafe Embedding CreateEmbedding(string input)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            if (!_params.embedding)
            {
                throw new RuntimeError("Llama model must be created with embedding=true to call this method");
            }

            if (_verbose)
            {
                NativeApi.llama_reset_timings(_ctx);
            }

            var tokens = Tokenize(input);
            Reset();
            Eval(tokens);
            int n_tokens = tokens.Count;
            var embeddingPtr = NativeApi.llama_get_embeddings(_ctx);
            int cnt = NativeApi.llama_n_embd(_ctx);
            float[] embedding = new float[cnt];
            for (int i = 0; i < cnt; i++)
            {
                embedding[i] = embeddingPtr[i];
            }

            if (_verbose)
            {
                NativeApi.llama_print_timings(_ctx);
            }

            return new Embedding("list", _model_path, new[] { new EmbeddingData(0, "embedding", embedding) },
                new EmbeddingUsage(n_tokens, n_tokens));
        }

        public float[] Embed(string input)
        {
            return CreateEmbedding(input).Data[0].Embedding;
        }

        /// <summary>
        /// Internal implementation shared by <see cref="CreateCompletion"/> and <see cref="Call"/>.
        /// </summary>
        /// <param name="prompt"></param>
        /// <param name="suffix"></param>
        /// <param name="max_tokens"></param>
        /// <param name="temperature"></param>
        /// <param name="top_p"></param>
        /// <param name="logprobs"></param>
        /// <param name="echo"></param>
        /// <param name="stop"></param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="repeat_penalty"></param>
        /// <param name="top_k"></param>
        /// <returns>IEnumerable of CompletionChunk</returns>
        /// <exception cref="ArgumentException"></exception>
        private IEnumerable<CompletionChunk> CreateCompletionInternal(string prompt, string? suffix = null, int max_tokens = 16, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            string completionId = $"cmpl-{Guid.NewGuid()}";
            var created = (int)DateTimeOffset.UtcNow.ToUnixTimeSeconds();
            List<llama_token> completionTokens = new List<llama_token>();

            var promptTokens = Tokenize($" {prompt}");
            string text = "";
            int returnedCharacters = 0;
            if (stop is null)
            {
                stop = new string[0];
            }

            if (_verbose)
            {
                NativeApi.llama_reset_timings(_ctx);
            }

            if (promptTokens.Count + max_tokens > NativeApi.llama_n_ctx(_ctx))
            {
                throw new ArgumentException($"Requested tokens exceed context window of {NativeApi.llama_n_ctx(_ctx)}");
            }
            if (logprobs != -1 && !_params.logits_all)
            {
                throw new ArgumentException("logprobs is not supported for models created with logits_all=false");
            }

            if (_cache is not null)
            {
                try
                {
                    // TODO(Rinne): revise it since it will compare reference instead of elements.
                    var cacheItem = _cache[promptTokens.ToArray()];
                    var cachePrefixLen = LongestTokenPrefix(cacheItem.EvalTokens, promptTokens);
                    var evalPrefixLen = LongestTokenPrefix(_eval_tokens.AsEnumerable(), promptTokens);
                    if (cachePrefixLen > evalPrefixLen)
                    {
                        LoadState(cacheItem);
                        if (_verbose)
                        {
                            Logger.Default.Info("Llama._create_completion: cache hit");
                        }
                    }
                }
                catch (KeyNotFoundException)
                {
                    if (_verbose)
                    {
                        Logger.Default.Warn("Llama._create_completion: cache miss");
                    }
                }
            }

            string finishReason = "length";
            int multibyteFix = 0;
            bool reset = true;
            List<llama_token> tokens = new(promptTokens);
            if (reset && _eval_tokens.Count > 0)
            {
                // Reuse the longest shared prefix between the evaluated tokens and the prompt.
                int longest_prefix = 0;
                foreach (var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count - 1), (x, y) => (x, y)))
                {
                    if (a == b)
                    {
                        longest_prefix += 1;
                    }
                    else
                    {
                        break;
                    }
                }
                if (longest_prefix > 0)
                {
                    if (_verbose)
                    {
                        Logger.Default.Info("Llama.generate: prefix-match hit");
                    }
                    reset = false;
                    tokens = tokens.Skip(longest_prefix).ToList();
                    for (int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
                    {
                        _eval_tokens.Dequeue();
                        if (_eval_logits.Count > 0)
                        {
                            _eval_logits.Dequeue();
                        }
                    }
                }
            }

            if (reset)
            {
                Reset();
            }
            //foreach (var token in Generate(promptTokens, top_k, top_p, temperature, frequency_penalty, presence_penalty, repeat_penalty))
            string allText = "";
            while (true)
            {
                Eval(tokens);
                var token = Sample(top_k, top_p, temperature, repeat_penalty, frequency_penalty, presence_penalty);
                tokens.Clear();
                tokens.Add(token);
                if (token == NativeApi.llama_token_eos())
                {
                    text = DeTokenize(completionTokens);
                    finishReason = "stop";
                    break;
                }

                completionTokens.Add(token);

                allText = DeTokenize(completionTokens);

                // If the tail of the text looks like the start of a truncated UTF-8
                // multi-byte sequence, hold back output until the sequence completes.
                int cut = Math.Min(3, allText.Length);
                for (int i = allText.Length - cut; i < allText.Length; i++)
                {
                    var c = (int)allText[i];
                    int k = allText.Length - i;
                    foreach (var (num, pattern) in _numAndPatterns)
                    {
                        if (num > k && (pattern & c) == pattern)
                        {
                            multibyteFix = num - k;
                        }
                    }
                }

                if (multibyteFix > 0)
                {
                    multibyteFix--;
                    continue;
                }

                var anyStop = stop.Where(s => allText.Contains(s));
                if (anyStop.Any())
                {
                    var firstStop = anyStop.First();
                    text = allText.Substring(0, allText.IndexOf(firstStop));
                    finishReason = "stop";
                    break;
                }

                if (completionTokens.Count >= max_tokens)
                {
                    text = DeTokenize(completionTokens);
                    finishReason = "length";
                    break;
                }

                // Hold back any tail that could still turn into a stop sequence.
                var start = returnedCharacters;
                int longest = 0;
                foreach (var s in stop)
                {
                    for (int i = s.Length; i > 0; i--)
                    {
                        if (allText.EndsWith(s.Substring(0, i)))
                        {
                            if (i > longest)
                            {
                                longest = i;
                            }
                            break;
                        }
                    }
                }
                text = allText.Substring(0, allText.Length - longest);
                returnedCharacters += text.Length - start;
                yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
                {
                    new CompletionChoice(text.Substring(start), 0, null, finishReason)
                });
            }

            // Emit any text that has not been streamed out yet.
            if (returnedCharacters < text.Length)
            {
                yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
                {
                    new CompletionChoice(text.Substring(returnedCharacters), 0, null, finishReason)
                });
            }

            if (_cache is not null)
            {
                if (_verbose)
                {
                    Logger.Default.Info("Llama._create_completion: cache save");
                }
                _cache[promptTokens.Concat(completionTokens).ToArray()] = SaveState();
            }

            string textStr = text;
            if (echo)
            {
                textStr = prompt + textStr;
            }
            if (suffix is not null)
            {
                textStr = textStr + suffix;
            }

            CompletionLogprobs? logProbs = null;
            if (logprobs != -1)
            {
                int textOffset = 0;
                List<int> textOffsets = new();
                List<float> tokenLogprobs = new();
                List<string> tokenStrs = new();
                List<Dictionary<string, float>> topLogprobs = new();

                var allTokens = promptTokens.Concat(completionTokens).ToArray();
                var allTokenStrs = allTokens.Select(t => DeTokenize(new[] { t }));
                var allLogProbs = _eval_logits.Select(row => LogitsToLogprobs(row));

                foreach (var (token, tokenStr, logProbsToken) in allTokens.Zip(allTokenStrs, (x, y) => (x, y))
                    .Zip(allLogProbs, (x, y) => (x.x, x.y, y)))
                {
                    textOffsets.Add(textOffset);
                    textOffset += tokenStr.Length;
                    tokenStrs.Add(tokenStr);
                    var sortedLogprobs = logProbsToken.Zip(Enumerable.Range(0, logProbsToken.Count()), (x, y) => (x, y))
                        .OrderByDescending(x => x.x).ToList();
                    tokenLogprobs.Add(sortedLogprobs[token].x);
                    var topLogprob = sortedLogprobs.Take(logprobs).ToDictionary(t => DeTokenize(new[] { t.y }), t => t.x);
                    topLogprob[tokenStr] = sortedLogprobs[token].x;
                    topLogprobs.Add(topLogprob);
                }

                logProbs = new(textOffsets.ToArray(), tokenLogprobs.ToArray(), tokenStrs.ToArray(), topLogprobs.ToArray());
            }

            if (_verbose)
            {
                NativeApi.llama_print_timings(_ctx);
            }
        }

        /// <summary>
        /// Generate text from a prompt and yield return the result.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns></returns>
        public IEnumerable<CompletionChunk> CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            return CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k);
        }

        /// <summary>
        /// Generate text from a prompt and yield return the result. Alias of <see cref="CreateCompletion"/>.
        /// </summary>
        /// <param name="prompt">The prompt to generate text from.</param>
        /// <param name="suffix">A suffix to append to the generated text. If null, no suffix is appended.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="logprobs">The number of logprobs to return. If -1, no logprobs are returned.</param>
        /// <param name="echo">Whether to echo the prompt.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="frequency_penalty"></param>
        /// <param name="presence_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <returns></returns>
        public IEnumerable<CompletionChunk> Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
        {
            return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
                frequency_penalty, presence_penalty, repeat_penalty, top_k);
        }

        private ChatCompletion ConvertTextCompletionToChat(Completion completion)
        {
            return new ChatCompletion($"chat{completion.Id}", "chat.completion", completion.Created, completion.Model,
                new[] { new ChatCompletionChoice(0, new ChatCompletionMessage(ChatRole.Assistant, completion.Choices[0].Text),
                    completion.Choices[0].FinishReason) }, completion.Usage);
        }

        private IEnumerable<ChatCompletionChunk> ConvertTextCompletionChunksToChat(IEnumerable<CompletionChunk> chunks)
        {
            bool isFirst = true;
            foreach (var chunk in chunks)
            {
                if (isFirst)
                {
                    // The first chunk announces the assistant role; content follows in later chunks.
                    yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
                        new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta("assistant", null), null) });
                    isFirst = false;
                }
                yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
                    new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta(null, chunk.Choices[0].Text),
                        chunk.Choices[0].FinishReason) });
            }
        }

        /// <summary>
        /// Generate a chat completion from a list of messages and yield return the result.
        /// </summary>
        /// <param name="messages">A list of messages to generate a response for.</param>
        /// <param name="temperature">The temperature to use for sampling.</param>
        /// <param name="top_p">The top-p value to use for sampling.</param>
        /// <param name="top_k">The top-k value to use for sampling.</param>
        /// <param name="stop">A list of strings to stop generation when encountered.</param>
        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
        /// <param name="presence_penalty"></param>
        /// <param name="frequency_penalty"></param>
        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
        /// <returns></returns>
        public IEnumerable<ChatCompletionChunk> CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
            int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
            float repeat_penalty = 1.1f)
        {
            if (stop is null)
            {
                stop = new string[0];
            }
            string GetRole(ChatCompletionMessage message)
            {
                return message.Role == ChatRole.Human ? "Human" : "Assistant";
            }
            // Render the history in the "### Human:"/"### Assistant:" instruct format
            // and stop as soon as the model starts a new turn.
            string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
            var prompt = chatHistory + "### Assistant:";
            prompt = prompt.Substring(Math.Max(0, prompt.Length - max_tokens));
            var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
            var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
                repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
            return ConvertTextCompletionChunksToChat(completion);
        }

        public LLamaState SaveState()
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            ulong stateSize = NativeApi.llama_get_state_size(_ctx);
            byte[] llamaState = new byte[stateSize];
            ulong nBytes = NativeApi.llama_copy_state_data(_ctx, llamaState);
            if (nBytes > stateSize)
            {
                throw new RuntimeError("Failed to copy llama state data");
            }
            // Shrink the buffer to the number of bytes actually written.
            byte[] llamaStateCompact = new byte[nBytes];
            Array.Copy(llamaState, llamaStateCompact, (int)nBytes);
            if (_verbose)
            {
                Logger.Default.Info($"Llama.save_state: saving {nBytes} bytes of llama state");
            }
            return new LLamaState(new Queue<llama_token>(_eval_tokens), new Queue<float[]>(_eval_logits),
                llamaStateCompact, (int)nBytes);
        }

        public void LoadState(LLamaState state)
        {
            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
            _eval_tokens = new Queue<llama_token>(state.EvalTokens);
            _eval_logits = new Queue<float[]>(state.EvalLogits);
            if (NativeApi.llama_set_state_data(_ctx, state.State) != (ulong)state.Size)
            {
                throw new RuntimeError("Failed to set llama state data");
            }
        }

        // Numerically this is a plain log-softmax: log(exp(x_i) / sum_j exp(x_j)).
        private static IEnumerable<float> LogitsToLogprobs(IEnumerable<float> logits)
        {
            var exps = logits.Select(x => (float)Math.Exp(x)).ToArray();
            var sumExps = exps.Sum();
            return exps.Select(x => (float)Math.Log(x / sumExps));
        }

        internal static int LongestTokenPrefix(IEnumerable<llama_token> a, IEnumerable<llama_token> b)
        {
            int longestPrefix = 0;
            foreach (var (x, y) in a.Zip(b, (x, y) => (x, y)))
            {
                if (x == y)
                {
                    longestPrefix++;
                }
                else
                {
                    break;
                }
            }
            return longestPrefix;
        }

        public void Dispose()
        {
            _ctx.Dispose();
        }
    }
}
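For reference, a minimal end-to-end sketch of the API this commit removes. This is hypothetical usage, not part of the diff: the model path and prompt are illustrative, and the calls rely only on the members defined above.

using System;
using LLama;

// Hypothetical usage of the removed LLamaModelV1 API.
using var model = new LLamaModelV1("models/ggml-7b.bin", n_ctx: 512, verbose: false);

// Streamed text completion; temperature 0 makes sampling greedy, and the stop
// sequences end generation at the next question or newline.
foreach (var chunk in model.CreateCompletion("Q: What is 2 + 2? A: ",
    max_tokens: 32, temperature: 0.0f, stop: new[] { "Q:", "\n" }))
{
    Console.Write(chunk.Choices[0].Text);
}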

LLama/LLamaState.cs (+0, -10)

@@ -1,10 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama
{
    using llama_token = Int32;

    public record LLamaState(Queue<llama_token> EvalTokens, Queue<float[]> EvalLogits,
        byte[] State, int Size);
}
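A minimal save/restore round trip over this record, sketched on the assumption that `model` is an initialized `LLamaModelV1` (hypothetical usage, not part of the commit):

// Snapshot the context state plus the eval queues, run more tokens,
// then restore to continue from the snapshot point.
LLamaState snapshot = model.SaveState();
model.Eval(model.Tokenize(" and then"));
model.LoadState(snapshot); // llama_set_state_data plus restored queues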

LLama/LLamaTypes.cs (+0, -41)

@@ -1,41 +0,0 @@
using System;
using System.Collections.Generic;
using System.Text;

namespace LLama.Types
{
    public enum ChatRole
    {
        Human,
        Assistant
    }

    public record EmbeddingUsage(int PromptTokens, int TotalTokens);

    public record EmbeddingData(int Index, string Object, float[] Embedding);

    public record Embedding(string Object, string Model, EmbeddingData[] Data, EmbeddingUsage Usage);

    public record CompletionLogprobs(int[] TextOffset, float[] TokenLogProbs, string[] Tokens, Dictionary<string, float>[] TopLogprobs);

    public record CompletionChoice(string Text, int Index, CompletionLogprobs? Logprobs, string? FinishReason);

    public record CompletionUsage(int PromptTokens, int CompletionTokens, int TotalTokens);

    public record CompletionChunk(string Id, string Object, int Created, string Model, CompletionChoice[] Choices);

    public record Completion(string Id, string Object, int Created, string Model, CompletionChoice[] Choices, CompletionUsage Usage);

    public record ChatCompletionMessage(ChatRole Role, string Content, string? Name = null);

    public record ChatCompletionChoice(int Index, ChatCompletionMessage Message, string? FinishReason);

    public record ChatCompletion(string Id, string Object, int Created, string Model, ChatCompletionChoice[] Choices, CompletionUsage Usage);

    public record ChatCompletionChunkDelta(string? Role, string? Content);

    public record ChatCompletionChunkChoice(int Index, ChatCompletionChunkDelta Delta, string? FinishReason);

    public record ChatCompletionChunk(string Id, string Model, string Object, int Created, ChatCompletionChunkChoice[] Choices);

    public record ChatMessageRecord(ChatCompletionMessage Message, DateTime Time);
}
