diff --git a/LLama.Examples/ChatWithLLamaModelV1.cs b/LLama.Examples/ChatWithLLamaModelV1.cs
deleted file mode 100644
index d8b45a05..00000000
--- a/LLama.Examples/ChatWithLLamaModelV1.cs
+++ /dev/null
@@ -1,39 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
-using LLama.Types;
-
-namespace LLama.Examples
-{
-    public class ChatWithLLamaModelV1
-    {
-        LLamaModelV1 _model;
-        public ChatWithLLamaModelV1(string modelPath)
-        {
-            _model = new(modelPath, logits_all: false, verbose: false, n_ctx: 512);
-        }
-
-        public void Run()
-        {
-            List<ChatCompletionMessage> chats = new List<ChatCompletionMessage>();
-            chats.Add(new ChatCompletionMessage(ChatRole.Human, "Hi, Alice, I'm Rinne."));
-            chats.Add(new ChatCompletionMessage(ChatRole.Assistant, "Hi, Rinne, I'm Alice, an assistant that answer any question. What can I do for you?"));
-            while (true)
-            {
-                Console.Write("\nYou: ");
-                Console.ForegroundColor = ConsoleColor.Green;
-                var question = Console.ReadLine();
-                Console.ForegroundColor = ConsoleColor.White;
-                chats.Add(new ChatCompletionMessage(ChatRole.Human, question));
-                var outputs = _model.CreateChatCompletion(chats, max_tokens: 256);
-                Console.Write($"LLama AI: ");
-                foreach (var output in outputs)
-                {
-                    Console.Write($"{output.Choices[0].Delta.Content}");
-                }
-            }
-        }
-    }
-}
diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs
index 97b0e3f3..106c4eaa 100644
--- a/LLama.Examples/Program.cs
+++ b/LLama.Examples/Program.cs
@@ -1,6 +1,5 @@
 using LLama;
 using LLama.Examples;
-using LLama.Types;
 
 Console.WriteLine("================LLamaSharp Examples==================\n");
 
diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs
index dac9caa3..29178432 100644
--- a/LLama.Unittest/BasicTest.cs
+++ b/LLama.Unittest/BasicTest.cs
@@ -5,11 +5,7 @@ namespace LLama.Unittest
         [Fact]
         public void SimpleQA()
         {
-            string modelPath = @"D:\development\llama\weights\LLaMA\7B\ggml-model-f32.bin";
-            LLamaModelV1 model = new(modelPath, logits_all: false);
-            var output = model.Call("Q: Why God makes many people believe him? A: ", max_tokens: 64, stop: new[] { "Q:", "\n" },
-                echo: true);
-            Console.WriteLine(output);
+
         }
     }
 }
\ No newline at end of file
diff --git a/LLama/LLamaCache.cs b/LLama/LLamaCache.cs
deleted file mode 100644
index a13931ae..00000000
--- a/LLama/LLamaCache.cs
+++ /dev/null
@@ -1,99 +0,0 @@
-using System;
-using System.Collections;
-using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-
-namespace LLama
-{
-    using llama_token = Int32;
-    /// <summary>
-    /// Cache for a llama.cpp model.
-    /// </summary>
-    public class LLamaCache
-    {
-        private Dictionary<llama_token[], LinkedListNode<KeyValuePair<llama_token[], LLamaState>>> _cacheState;
-        private LinkedList<KeyValuePair<llama_token[], LLamaState>> _cacheList;
-        private int _capacity;
-
-        public int CacheSize
-        {
-            get
-            {
-                return _cacheState.Values.Select(s => s.Value.Value.Size).Sum();
-            }
-        }
-
-        /// <summary>
-        /// </summary>
-        /// <param name="capacity">The max capacity (bytes).</param>
-        public LLamaCache(int capacity = 2 << 30)
-        {
-            _cacheState = new();
-            _cacheList = new();
-            _capacity = capacity;
-        }
-
-        public LLamaState this[llama_token[] key]
-        {
-            get
-            {
-                var prefixKey = FindLongestPrefixKey(key);
-                if(prefixKey is null)
-                {
-                    throw new KeyNotFoundException();
-                }
-                var value = _cacheState[prefixKey];
-                MoveNodeToEnd(prefixKey);
-                return value.Value.Value;
-            }
-            set
-            {
-                var node = _cacheList.AddLast(new KeyValuePair<llama_token[], LLamaState>(key, value));
-                _cacheState[key] = node;
-                while(CacheSize > _capacity && _cacheList.Count > 0)
-                {
-                    var topop = _cacheList.First;
-                    _cacheState.Remove(topop.Value.Key);
-                    _cacheList.RemoveFirst();
-                }
-            }
-        }
-
-        public bool Contains(llama_token[] key)
-        {
-            return FindLongestPrefixKey(key) is not null;
-        }
-
-        private llama_token[]? FindLongestPrefixKey(llama_token[] key)
-        {
-            int minLen = 0;
-            llama_token[]? minKey = null;
-            var keys = _cacheState.Keys.Select(k => (k, LLamaModelV1.LongestTokenPrefix(k, key)));
-            foreach(var (k, prefixLen) in keys)
-            {
-                if(prefixLen > minLen)
-                {
-                    minLen = prefixLen;
-                    minKey = k;
-                }
-            }
-            return minKey;
-        }
-
-        private void MoveNodeToEnd(llama_token[] key)
-        {
-            if (!_cacheState.TryGetValue(key, out var node))
-            {
-                return;
-            }
-
-            _cacheState.Remove(key);
-            _cacheList.Remove(node);
-
-            var newNode = _cacheList.AddLast(new KeyValuePair<llama_token[], LLamaState>(key, node.Value.Value));
-            _cacheState.Add(key, newNode);
-        }
-    }
-}
diff --git a/LLama/LLamaModelV1.cs b/LLama/LLamaModelV1.cs
deleted file mode 100644
index b9e5b822..00000000
--- a/LLama/LLamaModelV1.cs
+++ /dev/null
@@ -1,836 +0,0 @@
-using LLama.Exceptions;
-using LLama.Native;
-using System;
-using System.Collections.Generic;
-using System.Configuration;
-using System.Diagnostics;
-using System.IO;
-using System.Linq;
-using System.Runtime.CompilerServices;
-using System.Text;
-using LLama.Types;
-using System.Runtime.InteropServices;
-using System.Text.RegularExpressions;
-using System.Collections;
-
-namespace LLama
-{
-    using llama_token = Int32;
-    /// <summary>
-    /// High-level Wrapper of a llama.cpp model for inference. Note that it's more recommended to use `LLamaModel`.
-    /// This class may be removed in the future. However, if all you want is to get the embeddings, then using `LLamaModelV1`
-    /// is ok now.
-    /// </summary>
-    [Obsolete]
-    public class LLamaModelV1: IDisposable
-    {
-        private string _model_path;
-        LLamaContextParams _params;
-        private int _n_threads;
-        private int _n_batch;
-        private int _last_n_tokens_size;
-        private string? _lora_base;
-        private string? _lora_path;
-        private bool _verbose;
-
-        private Queue<llama_token> _eval_tokens;
-        private Queue<float[]> _eval_logits;
-        private LLamaCache? _cache;
-        private SafeLLamaContextHandle _ctx;
-
-        private static readonly (int, int)[] _numAndPatterns = new (int, int)[] { (2, 192), (3, 224), (4, 240) };
-
-        /// <summary>
-        /// Load a llama.cpp model from the path.
-        ///
-        /// Note that the API is still unstable. The order of them is likely to
-        /// be changed in the future. It's recommened to specify the parameter name when
-        /// building your app. We use the cpp style parameter names here because it introduces
-        /// convenience for searching the docs.
-        /// </summary>
-        /// <param name="model_path">Path to the model.</param>
-        /// <param name="n_ctx">Maximum context size.</param>
-        /// <param name="n_parts">Number of parts to split the model into. If -1, the number of parts is automatically determined.</param>
-        /// <param name="seed">Random seed. 0 for random.</param>
-        /// <param name="f16_kv">Use half-precision for key/value cache.</param>
-        /// <param name="logits_all">Return logits for all tokens, not just the last token.</param>
-        /// <param name="vocab_only">Only load the vocabulary no weights.</param>
-        /// <param name="use_mmap">Use mmap if possible.</param>
-        /// <param name="use_mlock">Force the system to keep the model in RAM.</param>
-        /// <param name="embedding">Embedding mode only.</param>
-        /// <param name="n_threads">Number of threads to use. If is not specified, the number of threads is automatically determined.</param>
-        /// <param name="n_batch">Maximum number of prompt tokens to batch together when calling llama_eval.</param>
-        /// <param name="last_n_tokens_size">Maximum number of tokens to keep in the last_n_tokens deque.</param>
-        /// <param name="lora_base">Optional path to base model, useful if using a quantized base model and you want to apply LoRA to an f16 model.</param>
-        /// <param name="lora_path">Path to a LoRA file to apply to the model.</param>
-        /// <param name="verbose">Print verbose output to stderr.</param>
-        public LLamaModelV1(string model_path, int n_ctx = 512, int n_parts = -1, int seed = 1337,
-            bool f16_kv = true, bool logits_all = false, bool vocab_only = false, bool use_mmap = true,
-            bool use_mlock = false, bool embedding = false, int n_threads = -1, int n_batch = 512,
-            int last_n_tokens_size = 64, string? lora_base = null, string? lora_path = null, bool verbose = true)
-        {
-            _verbose = verbose;
-            _model_path = model_path;
-
-            _params = NativeApi.llama_context_default_params();
-            _params.n_ctx = n_ctx;
-            _params.seed = seed;
-            _params.f16_kv = f16_kv;
-            _params.logits_all = logits_all;
-            _params.vocab_only = vocab_only;
-            _params.use_mmap = lora_path is null ? use_mmap : false;
-            _params.use_mlock = use_mlock;
-            _params.embedding = embedding;
-
-            _last_n_tokens_size = last_n_tokens_size;
-            _n_batch = Math.Min(n_ctx, n_batch);
-
-            _eval_tokens = new Queue<llama_token>(capacity: n_ctx);
-            _eval_logits = new Queue<float[]>(logits_all ? n_ctx : 1);
-
-            _cache = null;
-
-            _n_threads = n_threads;
-            if(_n_threads == -1)
-            {
-                _n_threads = Math.Max(Environment.ProcessorCount / 2, 1);
-            }
-
-            _lora_base = lora_base;
-            _lora_path = lora_path;
-
-            if(!File.Exists(model_path) && !Directory.Exists(model_path))
-            {
-                throw new FileNotFoundException($"Model path does not exist: {model_path}");
-            }
-
-            // Move from heap to stack to prevent the moving.
-            _ctx = new SafeLLamaContextHandle(NativeApi.llama_init_from_file(Encoding.UTF8.GetString(Encoding.UTF8.GetBytes(model_path)), _params));
-
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-
-            if(_lora_path is not null)
-            {
-                if(NativeApi.llama_apply_lora_from_file(_ctx, lora_path, lora_base, _n_threads) != 0)
-                {
-                    throw new RuntimeError($"Failed to apply LoRA from lora path: {_lora_path} to base path: {_lora_base}");
-                }
-            }
-
-            if (_verbose)
-            {
-                Logger.Default.Info(Utils.PtrToStringUTF8(NativeApi.llama_print_system_info()));
-            }
-        }
-
-        public LLamaModelV1(LLamaModelV1 other)
-        {
-            _ctx = other._ctx;
-            _model_path = other._model_path;
-            _params = other._params;
-            _last_n_tokens_size = other._last_n_tokens_size;
-            _n_threads = other._n_threads;
-            _n_batch = other._n_batch;
-            _verbose = other._verbose;
-            _lora_base = other._lora_base;
-            _lora_path = other._lora_path;
-            _eval_logits = new Queue<float[]>(other._eval_logits);
-            _eval_tokens = new Queue<llama_token>(other._eval_tokens);
-        }
-
-        /// <summary>
-        /// Tokenize a string.
-        /// </summary>
-        /// <param name="text">The utf-8 encoded string to tokenize.</param>
-        /// <returns>A list of tokens.</returns>
-        /// <exception cref="RuntimeError">If the tokenization failed.</exception>
-        public List<llama_token> Tokenize(string text)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            var n_ctx = NativeApi.llama_n_ctx(_ctx);
-            var tokens = new llama_token[n_ctx];
-            var n_tokens = NativeApi.llama_tokenize(_ctx, text, tokens, n_ctx, true);
-            if(n_tokens < 0)
-            {
-                throw new RuntimeError($"Failed to tokenize: text=\"{text}\" n_tokens={n_tokens}");
-            }
-            return tokens.Take(n_tokens).ToList();
-        }
-
-        /// <summary>
-        /// Detokenize a list of tokens.
-        /// </summary>
-        /// <param name="tokens">The list of tokens to detokenize.</param>
-        /// <returns>The detokenized string.</returns>
-        public string DeTokenize(IEnumerable<llama_token> tokens)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            string output = "";
-            foreach(var token in tokens)
-            {
-                output += Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token));
-            }
-            return output;
-        }
-
-        public string DeTokenize(llama_token token)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            return Utils.PtrToStringUTF8(NativeApi.llama_token_to_str(_ctx, token)) ?? "";
-        }
-
-        /// <summary>
-        /// Set the cache.
-        /// </summary>
-        /// <param name="cache">The cache to set.</param>
-        public void SetCache(LLamaCache? cache)
-        {
-            _cache = cache;
-        }
-
-        /// <summary>
-        /// Reset the model state.
-        /// </summary>
-        public void Reset()
-        {
-            _eval_tokens.Clear();
-            _eval_logits.Clear();
-        }
-
-        /// <summary>
-        /// Evaluate a list of tokens.
-        /// </summary>
-        /// <param name="tokens">The list of tokens to evaluate.</param>
-        /// <exception cref="RuntimeError"></exception>
-        public unsafe void Eval(List<llama_token> tokens)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            var n_ctx = NativeApi.llama_n_ctx(_ctx);
-            for(int i = 0; i < tokens.Count; i += _n_batch)
-            {
-                var batch = tokens.Take(Math.Min(tokens.Count, i + _n_batch)).Skip(i);
-                llama_token n_past = Math.Min(n_ctx - batch.Count(), _eval_tokens.Count);
-                llama_token n_tokens = batch.Count();
-                llama_token return_code = NativeApi.llama_eval(
-                    ctx: _ctx,
-                    tokens: batch.ToArray(),
-                    n_tokens: n_tokens,
-                    n_past: n_past,
-                    n_threads: _n_threads
-                );
-                if(return_code != 0)
-                {
-                    throw new RuntimeError($"llama_eval returned {return_code}");
-                }
-                foreach(var b in batch)
-                {
-                    _eval_tokens.Enqueue(b);
-                }
-                int rows = _params.logits_all ? n_tokens : 1;
-                llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
-                var cols = n_vocab;
-                var logits_view = NativeApi.llama_get_logits(_ctx);
-                for(int j = 0; j < rows; j++)
-                {
-                    float[] logit = new float[cols];
-                    for(int k = 0; k < cols; k++)
-                    {
-                        logit[k] = logits_view[j * cols + k];
-                    }
-                    _eval_logits.Enqueue(logit);
-                }
-            }
-        }
-
-        private llama_token SampleInternal(llama_token[] last_n_tokens_data, int last_n_tokens_size, int top_k,
-            float top_p, float temp, float repeat_penalty, float frequency_penalty, float presence_penalty)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            Debug.Assert(_eval_logits.Count > 0);
-            llama_token n_vocab = NativeApi.llama_n_vocab(_ctx);
-            var logits = _eval_logits.Last();
-            LLamaTokenData[] data = new LLamaTokenData[n_vocab];
-            for(int i = 0; i < n_vocab; i++)
-            {
-                data[i] = new LLamaTokenData(i, logits[i], .0f);
-            }
-            ulong size = (ulong)n_vocab;
-            bool sorted = false;
-            LLamaTokenDataArray candidates = new(data, size, sorted);
-            SamplingApi.llama_sample_repetition_penalty(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
-                repeat_penalty);
-            //SamplingApi.llama_sample_frequency_and_presence_penalties(_ctx, candidates, last_n_tokens_data, (ulong)last_n_tokens_size,
-            //    frequency_penalty, presence_penalty);
-            if(temp == .0f)
-            {
-                return SamplingApi.llama_sample_token_greedy(_ctx, candidates);
-            }
-            else
-            {
-                SamplingApi.llama_sample_top_k(_ctx, candidates, top_k, 1);
-                SamplingApi.llama_sample_tail_free(_ctx, candidates, 1.0f, 1);
-                SamplingApi.llama_sample_typical(_ctx, candidates, 1.0f, 1);
-                SamplingApi.llama_sample_top_p(_ctx, candidates, top_p, 1);
-                SamplingApi.llama_sample_temperature(_ctx, candidates, temp);
-                return SamplingApi.llama_sample_token(_ctx, candidates);
-            }
-        }
-
-        /// <summary>
-        /// Sample a token from the model.
-        /// </summary>
-        /// <param name="top_k">The top-k sampling parameter.</param>
-        /// <param name="top_p">The top-p sampling parameter.</param>
-        /// <param name="temp">The temperature parameter.</param>
-        /// <param name="repeat_penalty">The repeat penalty parameter.</param>
-        /// <param name="frequency_penalty"></param>
-        /// <param name="presence_penalty"></param>
-        /// <returns>The sampled token.</returns>
-        public llama_token Sample(int top_k, float top_p, float temp, float repeat_penalty, float frequency_penalty = .0f,
-            float presence_penalty = .0f)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            var last_n_tokens_data = Enumerable.Repeat(0, Math.Max(0, _last_n_tokens_size - _eval_tokens.Count));
-            last_n_tokens_data = last_n_tokens_data.Concat(_eval_tokens.ToList()
-                .Skip(Math.Max(0, _eval_tokens.Count - _last_n_tokens_size)));
-            llama_token[] tokens_data = new llama_token[_last_n_tokens_size];
-            int i = 0;
-            foreach(var data in last_n_tokens_data)
-            {
-                if(i < _last_n_tokens_size)
-                {
-                    tokens_data[i++] = data;
-                }
-                else
-                {
-                    break;
-                }
-            }
-            return SampleInternal(tokens_data, _last_n_tokens_size, top_k, top_p, temp, repeat_penalty, frequency_penalty, presence_penalty);
-        }
-
-        /// <summary>
-        /// Create a generator of tokens from a prompt.
-        ///
-        /// Examples:
-        ///     var llama = new LlamaModel("models/ggml-7b.bin")
-        ///     var tokens = llama.Tokenize(b"Hello, world!")
-        ///     foreach(var token in llama.Generate(tokens, top_k:40, top_p:0.95, temp:1.0, repeat_penalty:1.1)){
-        ///         Console.WriteLine(llama.DeTokenize(new []{token}));
-        ///     }
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <param name="top_k"></param>
-        /// <param name="top_p"></param>
-        /// <param name="temp"></param>
-        /// <param name="repeat_penalty"></param>
-        /// <param name="frequency_penalty"></param>
-        /// <param name="presence_penalty"></param>
-        /// <param name="reset"></param>
-        /// <returns></returns>
-        public IEnumerable<llama_token> Generate(IEnumerable<llama_token> tokens, int top_k, float top_p, float temp,
-            float repeat_penalty, float frequency_penalty = .0f, float presence_penalty = .0f, bool reset = true)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            if(reset && _eval_tokens.Count > 0)
-            {
-                int longest_prefix = 0;
-                foreach(var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count() - 1), (x, y) => (x, y)))
-                {
-                    if(a == b)
-                    {
-                        longest_prefix += 1;
-                    }
-                    else
-                    {
-                        break;
-                    }
-                }
-                if(longest_prefix > 0)
-                {
-                    if (_verbose)
-                    {
-                        Logger.Default.Info("Llama.generate: prefix-match hit");
-                    }
-                    reset = false;
-                    tokens = tokens.Skip(longest_prefix);
-                    for(int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
-                    {
-                        _eval_tokens.Dequeue();
-                        if(_eval_logits.Count > 0)
-                        {
-                            _eval_logits.Dequeue();
-                        }
-                    }
-                }
-            }
-
-            if (reset)
-            {
-                Reset();
-            }
-
-            while (true)
-            {
-                Eval(tokens.ToList());
-                var token = Sample(top_k, top_p, temp, frequency_penalty, presence_penalty, repeat_penalty);
-                yield return token;
-                // TODO(Rinne): verify if the implementation is correct.
-            }
-        }
-
-        /// <summary>
-        /// Embed a string.
-        /// </summary>
-        /// <param name="input">The utf-8 encoded string to embed.</param>
-        /// <returns>An embedding object.</returns>
-        /// <exception cref="RuntimeError"></exception>
-        public unsafe Embedding CreateEmbedding(string input)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            if (!_params.embedding)
-            {
-                throw new RuntimeError("Llama model must be created with embedding=True to call this method");
-            }
-
-            if (_verbose)
-            {
-                NativeApi.llama_reset_timings(_ctx);
-            }
-
-            var tokens = Tokenize(input);
-            Reset();
-            Eval(tokens);
-            int n_tokens = tokens.Count;
-            var embeddingPtr = NativeApi.llama_get_embeddings(_ctx);
-            int cnt = NativeApi.llama_n_embd(_ctx);
-            float[] embedding = new float[cnt];
-            for(int i = 0; i < cnt; i++)
-            {
-                embedding[i] = embeddingPtr[i];
-            }
-
-            if (_verbose)
-            {
-                NativeApi.llama_print_timings(_ctx);
-            }
-
-            return new Embedding("list", _model_path, new[] { new EmbeddingData(0, "embedding", embedding) },
-                new EmbeddingUsage(n_tokens, n_tokens));
-        }
-
-        public float[] Embed(string input)
-        {
-            return CreateEmbedding(input).Data[0].Embedding;
-        }
-
-        /// <summary>
-        ///
-        /// </summary>
-        /// <param name="prompt"></param>
-        /// <param name="suffix"></param>
-        /// <param name="max_tokens"></param>
-        /// <param name="temperature"></param>
-        /// <param name="top_p"></param>
-        /// <param name="logprobs"></param>
-        /// <param name="echo"></param>
-        /// <param name="stop"></param>
-        /// <param name="frequency_penalty"></param>
-        /// <param name="presence_penalty"></param>
-        /// <param name="repeat_penalty"></param>
-        /// <param name="top_k"></param>
-        /// <returns>IEnumerable of Completion and CompletionChunk</returns>
-        /// <exception cref="ArgumentException"></exception>
-        private IEnumerable<CompletionChunk> CreateCompletionInternal(string prompt, string? suffix = null, int max_tokens = 16, float temperature = 0.8f,
-            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
-            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            string completionId = $"cmpl-{Guid.NewGuid()}";
-            var created = DateTime.Now.Millisecond;
-            List<llama_token> completionTokens = new List<llama_token>();
-
-            var promptTokens = Tokenize($" {prompt}");
-            string text = "";
-            int returnedCharacters = 0;
-            if(stop is null)
-            {
-                stop = new string[0];
-            }
-
-            if (_verbose)
-            {
-                NativeApi.llama_reset_timings(_ctx);
-            }
-
-            if(promptTokens.Count + max_tokens > NativeApi.llama_n_ctx(_ctx))
-            {
-                throw new ArgumentException($"Requested tokens exceed context window of {NativeApi.llama_n_ctx(_ctx)}");
-            }
-
-            if(logprobs != -1 && !_params.logits_all)
-            {
-                throw new ArgumentException("logprobs is not supported for models created with logits_all=False");
-            }
-
-            if(_cache is not null)
-            {
-                try
-                {
-                    // TODO(Rinne): revise it since it will compare reference instead of elements.
-                    var cacheItem = _cache[promptTokens.ToArray()];
-                    var cachePrefixLen = LongestTokenPrefix(_eval_tokens.AsEnumerable(), promptTokens);
-                    var evalPrefixLen = LongestTokenPrefix(_eval_tokens.AsEnumerable(), promptTokens);
-                    if(cachePrefixLen > evalPrefixLen)
-                    {
-                        LoadState(cacheItem);
-                        if (_verbose)
-                        {
-                            Logger.Default.Info("Llama._create_completion: cache hit");
-                        }
-                    }
-                }
-                catch (KeyNotFoundException)
-                {
-                    if (_verbose)
-                    {
-                        Logger.Default.Warn("Llama._create_completion: cache miss");
-                    }
-                }
-            }
-
-            string finishReason = "length";
-            int multibyteFix = 0;
-            bool reset = true;
-            List<llama_token> tokens = new(promptTokens);
-            if (reset && _eval_tokens.Count > 0)
-            {
-                int longest_prefix = 0;
-                foreach (var (a, b) in _eval_tokens.ToList().Zip(tokens.Take(tokens.Count - 1), (x, y) => (x, y)))
-                {
-                    if (a == b)
-                    {
-                        longest_prefix += 1;
-                    }
-                    else
-                    {
-                        break;
-                    }
-                }
-                if (longest_prefix > 0)
-                {
-                    if (_verbose)
-                    {
-                        Logger.Default.Info("Llama.generate: prefix-match hit");
-                    }
-                    reset = false;
-                    tokens = tokens.Skip(longest_prefix).ToList();
-                    for (int i = 0; i < _eval_tokens.Count - longest_prefix; i++)
-                    {
-                        _eval_tokens.Dequeue();
-                        if (_eval_logits.Count > 0)
-                        {
-                            _eval_logits.Dequeue();
-                        }
-                    }
-                }
-            }
-
-            if (reset)
-            {
-                Reset();
-            }
-            //foreach (var token in Generate(promptTokens, top_k, top_p, temperature, frequency_penalty, presence_penalty, repeat_penalty))
-            string allText = "";
-            while (true)
-            {
-                Eval(tokens);
-                var token = Sample(top_k, top_p, temperature, repeat_penalty, frequency_penalty, presence_penalty);
-                tokens.Clear();
-                tokens.Add(token);
-                if (token == NativeApi.llama_token_eos())
-                {
-                    text = DeTokenize(completionTokens);
-                    finishReason = "stop";
-                    break;
-                }
-
-                completionTokens.Add(token);
-
-                allText = DeTokenize(completionTokens);
-
-                int cut = Math.Min(3, allText.Length);
-                for(int i = allText.Length - cut; i < allText.Length; i++)
-                {
-                    var c = (int)allText[i];
-                    int k = cut - i;
-                    foreach(var (num, pattern) in _numAndPatterns)
-                    {
-                        if(num > k && (pattern & c) == pattern)
-                        {
-                            multibyteFix = num - k;
-                        }
-                    }
-                }
-
-                if(multibyteFix > 0)
-                {
-                    multibyteFix--;
-                    continue;
-                }
-
-                var anyStop = stop.Where(s => allText.Contains(s));
-                if(anyStop.Count() > 0)
-                {
-                    var firstStop = anyStop.First();
-                    text = allText.Substring(0, allText.IndexOf(firstStop));
-                    finishReason = "stop";
-                    break;
-                }
-
-                var start = returnedCharacters;
-                int longest = 0;
-                foreach (var s in stop)
-                {
-                    for (int i = s.Length; i > 0; i--)
-                    {
-                        if (allText.EndsWith(s.Substring(0, i)))
-                        {
-                            if (i > longest)
-                            {
-                                longest = i;
-                            }
-                            break;
-                        }
-                    }
-                }
-                text = allText.Substring(0, allText.Length - longest);
-                returnedCharacters += text.Skip(start).Count();
-                yield return new CompletionChunk(completionId, "text_completion", created, _model_path, new CompletionChoice[]
-                {
-                    new CompletionChoice(text.Substring(start), 0, null, finishReason)
-                });
-            }
-
-            if (_cache is not null)
-            {
-                if (_verbose)
-                {
-                    Logger.Default.Info("Llama._create_completion: cache save");
-                }
-                _cache[promptTokens.Concat(completionTokens).ToArray()] = SaveState();
-            }
-
-            string textStr = text;
-            if (echo)
-            {
-                textStr = prompt + textStr;
-            }
-            if(suffix is not null)
-            {
-                textStr = textStr + suffix;
-            }
-
-            CompletionLogprobs? logProbs = null;
-            if (logprobs != -1)
-            {
-                int textOffset = 0;
-                List<int> textOffsets = new();
-                List<float> tokenLogprobs = new();
-                List<string> tokenStrs = new();
-                List<Dictionary<string, float>> topLogprobs = new();
-
-                var allTokens = promptTokens.Concat(completionTokens).ToArray();
-                var allTokenStrs = allTokens.Select(t => DeTokenize(new[] { t }));
-                var allLogProbs = _eval_logits.Select(row => LogitsToLogprobs(row));
-
-                foreach (var (token, tokenStr, logProbsToken) in allTokens.Zip(allTokenStrs, (x, y) => (x, y))
-                    .Zip(allLogProbs, (x, y) => (x.x, x.y, y)))
-                {
-                    textOffsets.Add(textOffset);
-                    textOffset += tokenStr.Length;
-                    tokenStrs.Add(tokenStr);
-                    var sortedLogprobs = logProbsToken.Zip(Enumerable.Range(0, logProbsToken.Count()), (x, y) => (x, y))
-                        .OrderByDescending(x => x.x).ToList();
-                    tokenLogprobs.Add(sortedLogprobs[token].x);
-                    var topLogprob = sortedLogprobs.Take(logprobs).ToDictionary(t => DeTokenize(new[] { t.y }), t => t.x);
-                    topLogprob[tokenStr] = sortedLogprobs[token].x;
-                    topLogprobs.Add(topLogprob);
-                }
-
-                logProbs = new(textOffsets.ToArray(), tokenLogprobs.ToArray(), tokenStrs.ToArray(), topLogprobs.ToArray());
-            }
-
-            if (_verbose)
-            {
-                NativeApi.llama_print_timings(_ctx);
-            }
-        }
-
-        /// <summary>
-        /// Generate text from a prompt and yield return the result.
-        /// </summary>
-        /// <param name="prompt">The prompt to generate text from.</param>
-        /// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
-        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
-        /// <param name="temperature">The temperature to use for sampling.</param>
-        /// <param name="top_p">The top-p value to use for sampling.</param>
-        /// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
-        /// <param name="echo">Whether to echo the prompt.</param>
-        /// <param name="stop">A list of strings to stop generation when encountered.</param>
-        /// <param name="frequency_penalty"></param>
-        /// <param name="presence_penalty"></param>
-        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
-        /// <param name="top_k">The top-k value to use for sampling.</param>
-        /// <returns></returns>
-        public IEnumerable<CompletionChunk> CreateCompletion(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
-            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
-            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
-        {
-            return CreateCompletionInternal(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
-                frequency_penalty, presence_penalty, repeat_penalty, top_k);
-        }
-
-        /// <summary>
-        /// Generate text from a prompt and yield return the result.
-        /// </summary>
-        /// <param name="prompt">The prompt to generate text from.</param>
-        /// <param name="suffix">A suffix to append to the generated text. If None, no suffix is appended.</param>
-        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
-        /// <param name="temperature">The temperature to use for sampling.</param>
-        /// <param name="top_p">The top-p value to use for sampling.</param>
-        /// <param name="logprobs">The number of logprobs to return. If None, no logprobs are returned.</param>
-        /// <param name="echo">Whether to echo the prompt.</param>
-        /// <param name="stop">A list of strings to stop generation when encountered.</param>
-        /// <param name="frequency_penalty"></param>
-        /// <param name="presence_penalty"></param>
-        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
-        /// <param name="top_k">The top-k value to use for sampling.</param>
-        /// <returns></returns>
-        public IEnumerable<CompletionChunk> Call(string prompt, string? suffix = null, int max_tokens = 128, float temperature = 0.8f,
-            float top_p = 0.95f, int logprobs = -1, bool echo = false, string[]? stop = null, float frequency_penalty = .0f,
-            float presence_penalty = .0f, float repeat_penalty = 1.1f, int top_k = 40)
-        {
-            return CreateCompletion(prompt, suffix, max_tokens, temperature, top_p, logprobs, echo, stop,
-                frequency_penalty, presence_penalty, repeat_penalty, top_k);
-        }
-
-        private ChatCompletion ConvertTextCompletionToChat(Completion completion)
-        {
-            return new ChatCompletion($"chat{completion.Id}", "chat.completion", completion.Created, completion.Model,
-                new[] { new ChatCompletionChoice(0, new ChatCompletionMessage(ChatRole.Assistant, completion.Choices[0].Text),
-                    completion.Choices[0].FinishReason) }, completion.Usage);
-        }
-
-        private IEnumerable<ChatCompletionChunk> ConvertTextCompletionChunksToChat(IEnumerable<CompletionChunk> chunks)
-        {
-            bool isFirst = true;
-            foreach(var chunk in chunks)
-            {
-                if(isFirst)
-                {
-                    yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
-                        new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta("assistant", null), null) });
-                    isFirst = false;
-                }
-                yield return new ChatCompletionChunk($"chat{chunk.Id}", chunk.Model, "chat.completion.chunk", chunk.Created,
-                    new[] { new ChatCompletionChunkChoice(0, new ChatCompletionChunkDelta(null, chunk.Choices[0].Text),
-                        chunk.Choices[0].FinishReason) });
-            }
-        }
-
-        /// <summary>
-        /// Generate a chat completion from a list of messages and yield return the result.
-        /// </summary>
-        /// <param name="messages">A list of messages to generate a response for.</param>
-        /// <param name="temperature">The temperature to use for sampling.</param>
-        /// <param name="top_p">The top-p value to use for sampling.</param>
-        /// <param name="top_k">The top-k value to use for sampling.</param>
-        /// <param name="stop">A list of strings to stop generation when encountered.</param>
-        /// <param name="max_tokens">The maximum number of tokens to generate.</param>
-        /// <param name="presence_penalty"></param>
-        /// <param name="frequency_penalty"></param>
-        /// <param name="repeat_penalty">The penalty to apply to repeated tokens.</param>
-        /// <returns></returns>
-        public IEnumerable<ChatCompletionChunk> CreateChatCompletion(IEnumerable<ChatCompletionMessage> messages, float temperature = .2f, float top_p = .95f,
-            int top_k = 40, string[]? stop = null, int max_tokens = 256, float presence_penalty = .0f, float frequency_penalty = .0f,
-            float repeat_penalty = 1.1f)
-        {
-            if (stop is null)
-            {
-                stop = new string[0];
-            }
-            string GetRole(ChatCompletionMessage message)
-            {
-                return message.Role == ChatRole.Human ? "Human" : "Assistant";
-            }
-            string chatHistory = string.Join("", messages.Select(m => $"### {GetRole(m)}:{m.Content}"));
-            var prompt = chatHistory + "### Assistant:";
-            prompt = prompt.Substring(Math.Max(0, prompt.Length - max_tokens));
-            var promptStop = new[] { "### Assistant:", "### Human:" }.Concat(stop).ToArray();
-            var completion = Call(prompt, stop: promptStop, temperature: temperature, top_p: top_p, top_k: top_k, max_tokens: max_tokens,
-                repeat_penalty: repeat_penalty, presence_penalty: presence_penalty, frequency_penalty: frequency_penalty);
-            return ConvertTextCompletionChunksToChat(completion);
-        }
-
-        public LLamaState SaveState()
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            ulong stateSize = NativeApi.llama_get_state_size(_ctx);
-            byte[] llamaState = new byte[stateSize];
-            ulong nBytes = NativeApi.llama_copy_state_data(_ctx, llamaState);
-            if(nBytes > stateSize)
-            {
-                throw new RuntimeError("Failed to copy llama state data");
-            }
-            byte[] llamaStateCompact = new byte[nBytes];
-            llamaState.Take((int)nBytes).ToArray().CopyTo(llamaStateCompact, 0);
-            if (_verbose)
-            {
-                Logger.Default.Info($"Llama.save_state: saving {nBytes} bytes of llama state");
-            }
-            return new LLamaState(new Queue<llama_token>(_eval_tokens), new Queue<float[]>(_eval_logits),
-                llamaStateCompact, (int)nBytes);
-        }
-
-        public void LoadState(LLamaState state)
-        {
-            Debug.Assert(_ctx.DangerousGetHandle() != IntPtr.Zero);
-            _eval_tokens = new Queue<llama_token>(state.EvalTokens);
-            _eval_logits = new Queue<float[]>(state.EvalLogits);
-            if(NativeApi.llama_set_state_data(_ctx, state.State) != (ulong)state.Size)
-            {
-                throw new RuntimeError($"Failed to set llama state data");
-            }
-        }
-
-        private static IEnumerable<float> LogitsToLogprobs(IEnumerable<float> logits)
-        {
-            var exps = logits.Select(x => (float)Math.Exp(x));
-            var sumExps = exps.Sum();
-            return exps.Select(x => (float)Math.Log(x / sumExps));
-        }
-
-        internal static int LongestTokenPrefix(IEnumerable<llama_token> a, IEnumerable<llama_token> b)
-        {
-            int longestPrefix = 0;
-            foreach(var (x, y) in a.Zip(b, (x, y) => (x, y)))
-            {
-                if(x == y)
-                {
-                    longestPrefix++;
-                }
-                else
-                {
-                    break;
-                }
-            }
-            return longestPrefix;
-        }
-
-        public void Dispose()
-        {
-            _ctx.Dispose();
-        }
-    }
-}
diff --git a/LLama/LLamaState.cs b/LLama/LLamaState.cs
deleted file mode 100644
index 397d3431..00000000
--- a/LLama/LLamaState.cs
+++ /dev/null
@@ -1,10 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace LLama
-{
-    using llama_token = Int32;
-    public record LLamaState(Queue<llama_token> EvalTokens, Queue<float[]> EvalLogits,
-        byte[] State, int Size);
-}
diff --git a/LLama/LLamaTypes.cs b/LLama/LLamaTypes.cs
deleted file mode 100644
index 16da42e4..00000000
--- a/LLama/LLamaTypes.cs
+++ /dev/null
@@ -1,41 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
-
-namespace LLama.Types
-{
-    public enum ChatRole
-    {
-        Human,
-        Assistant
-    }
-    public record EmbeddingUsage(int PromptTokens, int TotalTokens);
-
-    public record EmbeddingData(int Index, string Object, float[] Embedding);
-
-    public record Embedding(string Object, string Model, EmbeddingData[] Data, EmbeddingUsage Usage);
-
-    public record CompletionLogprobs(int[] TextOffset, float[] TokenLogProbs, string[] Tokens, Dictionary<string, float>[] TopLogprobs);
-
-    public record CompletionChoice(string Text, int Index, CompletionLogprobs? Logprobs, string? FinishReason);
-
-    public record CompletionUsage(int PromptTokens, int CompletionTokens, int TotalTokens);
-
-    public record CompletionChunk(string Id, string Object, int Created, string Model, CompletionChoice[] Choices);
-
-    public record Completion(string Id, string Object, int Created, string Model, CompletionChoice[] Choices, CompletionUsage Usage);
-
-    public record ChatCompletionMessage(ChatRole Role, string Content, string? Name = null);
-
-    public record ChatCompletionChoice(int Index, ChatCompletionMessage Message, string? FinishReason);
-
-    public record ChatCompletion(string Id, string Object, int Created, string Model, ChatCompletionChoice[] Choices, CompletionUsage Usage);
-
-    public record ChatCompletionChunkDelta(string? Role, string? Content);
-
-    public record ChatCompletionChunkChoice(int Index, ChatCompletionChunkDelta Delta, string? FinishReason);
-
-    public record ChatCompletionChunk(string Id, string Model, string Object, int Created, ChatCompletionChunkChoice[] Choices);
-
-    public record ChatMessageRecord(ChatCompletionMessage Message, DateTime Time);
-}
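
A note on the removed LLamaCache: its dictionary is keyed by llama_token[] arrays, and the in-code TODO ("it will compare reference instead of elements") is accurate, because .NET arrays hash and compare by reference when used as dictionary keys, so a direct _cacheState[key] lookup can only hit when the exact same array instance is reused. A minimal sketch of a content-based comparer that would make such keys behave as intended; TokenArrayComparer is an illustrative name, not part of LLamaSharp:

    using System.Collections.Generic;
    using System.Linq;

    // Illustrative sketch: compares token arrays element-wise so that two
    // different array instances holding the same token ids act as the same key.
    public sealed class TokenArrayComparer : IEqualityComparer<int[]>
    {
        public bool Equals(int[]? x, int[]? y)
        {
            if (ReferenceEquals(x, y)) return true;
            if (x is null || y is null) return false;
            return x.Length == y.Length && x.SequenceEqual(y);
        }

        public int GetHashCode(int[] tokens)
        {
            // Order-sensitive hash over the token ids.
            unchecked
            {
                int hash = 17;
                foreach (var token in tokens)
                {
                    hash = hash * 31 + token;
                }
                return hash;
            }
        }
    }

Constructed as new Dictionary<int[], LLamaState>(new TokenArrayComparer()), exact-key lookups would then hit by content, while a linear scan like FindLongestPrefixKey would still be needed for partial-prefix matches.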
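Similarly, the removed LogitsToLogprobs exponentiates raw logits before normalizing, which can overflow to infinity for large logit values. A numerically stable log-softmax sketch under the same contract (a hypothetical standalone helper, not the library's API):

    using System;
    using System.Linq;

    public static class LogProbMath
    {
        // Stable log-softmax: log(exp(x_i) / sum_j exp(x_j))
        //                   = (x_i - max) - log(sum_j exp(x_j - max)),
        // where subtracting the max logit keeps every exp() argument <= 0.
        public static float[] LogitsToLogprobs(float[] logits)
        {
            float max = logits.Max();
            double logSumExp = Math.Log(logits.Sum(x => Math.Exp(x - max)));
            return logits.Select(x => (float)(x - max - logSumExp)).ToArray();
        }
    }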