| @@ -6,7 +6,10 @@ using LLama.Native; | |||
| namespace LLama.Extensions | |||
| { | |||
| internal static class IModelParamsExtensions | |||
| /// <summary> | |||
| /// Extention methods to the IModelParams interface | |||
| /// </summary> | |||
| public static class IModelParamsExtensions | |||
| { | |||
| /// <summary> | |||
| /// Convert the given `IModelParams` into a `LLamaContextParams` | |||
| @@ -31,7 +34,7 @@ namespace LLama.Extensions | |||
| result.n_gpu_layers = @params.GpuLayerCount; | |||
| result.seed = @params.Seed; | |||
| result.f16_kv = @params.UseFp16Memory; | |||
| result.use_mmap = @params.UseMemoryLock; | |||
| result.use_mmap = @params.UseMemorymap; | |||
| result.use_mlock = @params.UseMemoryLock; | |||
| result.logits_all = @params.Perplexity; | |||
| result.embedding = @params.EmbeddingMode; | |||
| @@ -55,7 +55,7 @@ namespace LLama | |||
| text = text.Insert(0, " "); | |||
| } | |||
| var embed_inp_array = Utils.Tokenize(_ctx, text, addBos, Encoding.GetEncoding(encoding)).ToArray(); | |||
| var embed_inp_array = _ctx.Tokenize(text, addBos, Encoding.GetEncoding(encoding)); | |||
| // TODO(Rinne): deal with log of prompt | |||
| @@ -68,9 +68,9 @@ namespace LLama | |||
| public LLamaModel Model => _model; | |||
| /// <summary> | |||
| /// Current "mu" value for mirostate sampling | |||
| /// Current "mu" value for mirostat sampling | |||
| /// </summary> | |||
| protected float MirostateMu { get; set; } = float.NaN; | |||
| protected float? MirostatMu { get; set; } | |||
| /// <summary> | |||
| /// | |||
| @@ -391,8 +391,8 @@ namespace LLama | |||
| [JsonPropertyName("last_tokens_maximum_count")] | |||
| public int LastTokensCapacity { get; set; } | |||
| [JsonPropertyName("mirostate_mu")] | |||
| public float MirostateMu { get; set; } | |||
| [JsonPropertyName("mirostat_mu")] | |||
| public float? MirostatMu { get; set; } | |||
| } | |||
| } | |||
| } | |||
| @@ -30,8 +30,8 @@ namespace LLama | |||
| public InstructExecutor(LLamaModel model, string instructionPrefix = "\n\n### Instruction:\n\n", | |||
| string instructionSuffix = "\n\n### Response:\n\n") : base(model) | |||
| { | |||
| _inp_pfx = _model.Tokenize(instructionPrefix, true).ToArray(); | |||
| _inp_sfx = _model.Tokenize(instructionSuffix, false).ToArray(); | |||
| _inp_pfx = _model.Tokenize(instructionPrefix, true); | |||
| _inp_sfx = _model.Tokenize(instructionSuffix, false); | |||
| _instructionPrefix = instructionPrefix; | |||
| } | |||
| @@ -53,7 +53,7 @@ namespace LLama | |||
| SessionFilePath = _pathSession, | |||
| SessionTokens = _session_tokens, | |||
| LastTokensCapacity = _last_n_tokens.Capacity, | |||
| MirostateMu = MirostateMu | |||
| MirostatMu = MirostatMu | |||
| }; | |||
| return state; | |||
| } | |||
| @@ -133,7 +133,7 @@ namespace LLama | |||
| _embed_inps.AddRange(_inp_sfx); | |||
| args.RemainedTokens -= line_inp.Count(); | |||
| args.RemainedTokens -= line_inp.Length; | |||
| } | |||
| } | |||
| /// <inheritdoc /> | |||
| @@ -146,9 +146,7 @@ namespace LLama | |||
| { | |||
| string last_output = ""; | |||
| foreach (var id in _last_n_tokens) | |||
| { | |||
| last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); | |||
| } | |||
| last_output += _model.NativeHandle.TokenToString(id, _model.Encoding); | |||
| foreach (var antiprompt in args.Antiprompts) | |||
| { | |||
| @@ -216,12 +214,12 @@ namespace LLama | |||
| var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var mu = MirostateMu; | |||
| var mu = MirostatMu; | |||
| var id = _model.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP | |||
| ); | |||
| MirostateMu = mu; | |||
| MirostatMu = mu; | |||
| _last_n_tokens.Enqueue(id); | |||
| @@ -25,7 +25,7 @@ namespace LLama | |||
| /// <param name="model"></param> | |||
| public InteractiveExecutor(LLamaModel model) : base(model) | |||
| { | |||
| _llama_token_newline = Utils.Tokenize(_model.NativeHandle, "\n", false, _model.Encoding).ToArray(); | |||
| _llama_token_newline = _model.NativeHandle.Tokenize("\n", false, _model.Encoding); | |||
| } | |||
| /// <inheritdoc /> | |||
| @@ -45,7 +45,7 @@ namespace LLama | |||
| SessionFilePath = _pathSession, | |||
| SessionTokens = _session_tokens, | |||
| LastTokensCapacity = _last_n_tokens.Capacity, | |||
| MirostateMu = MirostateMu | |||
| MirostatMu = MirostatMu | |||
| }; | |||
| return state; | |||
| } | |||
| @@ -114,7 +114,7 @@ namespace LLama | |||
| } | |||
| var line_inp = _model.Tokenize(text, false); | |||
| _embed_inps.AddRange(line_inp); | |||
| args.RemainedTokens -= line_inp.Count(); | |||
| args.RemainedTokens -= line_inp.Length; | |||
| } | |||
| } | |||
| @@ -133,7 +133,7 @@ namespace LLama | |||
| string last_output = ""; | |||
| foreach (var id in _last_n_tokens) | |||
| { | |||
| last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, id), _model.Encoding); | |||
| last_output += _model.NativeHandle.TokenToString(id, _model.Encoding); | |||
| } | |||
| foreach (var antiprompt in args.Antiprompts) | |||
| @@ -203,12 +203,12 @@ namespace LLama | |||
| var tokenDataArray = _model.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n, | |||
| inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); | |||
| var mu = MirostateMu; | |||
| var mu = MirostatMu; | |||
| var id = _model.Sample( | |||
| tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, | |||
| inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP | |||
| ); | |||
| MirostateMu = mu; | |||
| MirostatMu = mu; | |||
| _last_n_tokens.Enqueue(id); | |||
| @@ -78,10 +78,9 @@ namespace LLama | |||
| /// <param name="text"></param> | |||
| /// <param name="addBos">Whether to add a bos to the text.</param> | |||
| /// <returns></returns> | |||
| public IEnumerable<llama_token> Tokenize(string text, bool addBos = true) | |||
| public llama_token[] Tokenize(string text, bool addBos = true) | |||
| { | |||
| // TODO: reconsider whether to convert to array here. | |||
| return Utils.Tokenize(_ctx, text, addBos, _encoding); | |||
| return _ctx.Tokenize(text, addBos, _encoding); | |||
| } | |||
| /// <summary> | |||
| @@ -93,9 +92,7 @@ namespace LLama | |||
| { | |||
| StringBuilder sb = new(); | |||
| foreach(var token in tokens) | |||
| { | |||
| sb.Append(Utils.PtrToString(NativeApi.llama_token_to_str(_ctx, token), _encoding)); | |||
| } | |||
| sb.Append(_ctx.TokenToString(token, _encoding)); | |||
| return sb.ToString(); | |||
| } | |||
| @@ -245,7 +242,7 @@ namespace LLama | |||
| /// <param name="tfsZ"></param> | |||
| /// <param name="typicalP"></param> | |||
| /// <returns></returns> | |||
| public llama_token Sample(LLamaTokenDataArray candidates, ref float mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, | |||
| public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable, | |||
| float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f) | |||
| { | |||
| llama_token id; | |||
| @@ -256,30 +253,31 @@ namespace LLama | |||
| } | |||
| else | |||
| { | |||
| if (float.IsNaN(mirostat_mu)) | |||
| mirostat_mu = 2 * mirostatTau; | |||
| if (mirostat == MirostatType.Mirostat) | |||
| { | |||
| const int mirostat_m = 100; | |||
| SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); | |||
| id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mirostat_mu); | |||
| } | |||
| else if (mirostat == MirostatType.Mirostat2) | |||
| var mu = mirostat_mu ?? (2 * mirostatTau); | |||
| { | |||
| SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); | |||
| id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mirostat_mu); | |||
| } | |||
| else | |||
| { | |||
| // Temperature sampling | |||
| SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1); | |||
| SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1); | |||
| SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1); | |||
| SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1); | |||
| SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); | |||
| id = SamplingApi.llama_sample_token(_ctx, candidates); | |||
| if (mirostat == MirostatType.Mirostat) | |||
| { | |||
| const int mirostat_m = 100; | |||
| SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); | |||
| id = SamplingApi.llama_sample_token_mirostat(_ctx, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu); | |||
| } | |||
| else if (mirostat == MirostatType.Mirostat2) | |||
| { | |||
| SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); | |||
| id = SamplingApi.llama_sample_token_mirostat_v2(_ctx, candidates, mirostatTau, mirostatEta, ref mu); | |||
| } | |||
| else | |||
| { | |||
| // Temperature sampling | |||
| SamplingApi.llama_sample_top_k(_ctx, candidates, topK, 1); | |||
| SamplingApi.llama_sample_tail_free(_ctx, candidates, tfsZ, 1); | |||
| SamplingApi.llama_sample_typical(_ctx, candidates, typicalP, 1); | |||
| SamplingApi.llama_sample_top_p(_ctx, candidates, topP, 1); | |||
| SamplingApi.llama_sample_temperature(_ctx, candidates, temperature); | |||
| id = SamplingApi.llama_sample_token(_ctx, candidates); | |||
| } | |||
| } | |||
| mirostat_mu = mu; | |||
| } | |||
| return id; | |||
| } | |||
| @@ -299,8 +297,8 @@ namespace LLama | |||
| int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f, | |||
| bool penalizeNL = true) | |||
| { | |||
| var n_vocab = NativeApi.llama_n_vocab(_ctx); | |||
| var logits = Utils.GetLogits(_ctx, n_vocab); | |||
| var n_vocab = _ctx.VocabCount; | |||
| var logits = _ctx.GetLogits(); | |||
| // Apply params.logit_bias map | |||
| if(logitBias is not null) | |||
| @@ -352,7 +350,7 @@ namespace LLama | |||
| n_eval = Params.BatchSize; | |||
| } | |||
| if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0) | |||
| if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads)) | |||
| { | |||
| _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error); | |||
| throw new RuntimeError("Failed to eval."); | |||
| @@ -367,9 +365,7 @@ namespace LLama | |||
| internal IEnumerable<string> GenerateResult(IEnumerable<llama_token> ids) | |||
| { | |||
| foreach(var id in ids) | |||
| { | |||
| yield return Utils.TokenToString(id, _ctx, _encoding); | |||
| } | |||
| yield return _ctx.TokenToString(id, _encoding); | |||
| } | |||
| /// <inheritdoc /> | |||
| @@ -31,7 +31,7 @@ namespace LLama | |||
| _model = model; | |||
| var tokens = model.Tokenize(" ", true).ToArray(); | |||
| Utils.Eval(_model.NativeHandle, tokens, 0, tokens.Length, 0, _model.Params.Threads); | |||
| _model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads); | |||
| _originalState = model.GetState(); | |||
| } | |||
| @@ -52,12 +52,12 @@ namespace LLama | |||
| List<llama_token> tokens = _model.Tokenize(text, true).ToList(); | |||
| int n_prompt_tokens = tokens.Count; | |||
| Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, n_prompt_tokens, n_past, _model.Params.Threads); | |||
| _model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads); | |||
| lastTokens.AddRange(tokens); | |||
| n_past += n_prompt_tokens; | |||
| var mu = float.NaN; | |||
| var mu = (float?)null; | |||
| int max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; | |||
| for(int i = 0; i < max_tokens; i++) | |||
| { | |||
| @@ -76,7 +76,7 @@ namespace LLama | |||
| lastTokens.Add(id); | |||
| string response = Utils.TokenToString(id, _model.NativeHandle, _model.Encoding); | |||
| string response = _model.NativeHandle.TokenToString(id, _model.Encoding); | |||
| yield return response; | |||
| tokens.Clear(); | |||
| @@ -87,7 +87,7 @@ namespace LLama | |||
| string last_output = ""; | |||
| foreach (var token in lastTokens) | |||
| { | |||
| last_output += Utils.PtrToString(NativeApi.llama_token_to_str(_model.NativeHandle, token), _model.Encoding); | |||
| last_output += _model.NativeHandle.TokenToString(token, _model.Encoding); | |||
| } | |||
| bool should_break = false; | |||
| @@ -207,6 +207,17 @@ namespace LLama.Native | |||
| [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads); | |||
| /// <summary> | |||
| /// Run the llama inference to obtain the logits and probabilities for the next token. | |||
| /// tokens + n_tokens is the provided batch of new tokens to process | |||
| /// n_past is the number of tokens to use from previous eval calls | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="tokens"></param> | |||
| /// <param name="n_tokens"></param> | |||
| /// <param name="n_past"></param> | |||
| /// <param name="n_threads"></param> | |||
| /// <returns>Returns 0 on success</returns> | |||
| [DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)] | |||
| public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads); | |||
| @@ -218,6 +229,7 @@ namespace LLama.Native | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <param name="text"></param> | |||
| /// <param name="encoding"></param> | |||
| /// <param name="tokens"></param> | |||
| /// <param name="n_max_tokens"></param> | |||
| /// <param name="add_bos"></param> | |||
| @@ -256,8 +268,8 @@ namespace LLama.Native | |||
| /// <summary> | |||
| /// Token logits obtained from the last call to llama_eval() | |||
| /// The logits for the last token are stored in the last row | |||
| /// Can be mutated in order to change the probabilities of the next token | |||
| /// Rows: n_tokens | |||
| /// Can be mutated in order to change the probabilities of the next token.<br /> | |||
| /// Rows: n_tokens<br /> | |||
| /// Cols: n_vocab | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| @@ -1,4 +1,6 @@ | |||
| using System; | |||
| using System.Buffers; | |||
| using System.Text; | |||
| using LLama.Exceptions; | |||
| namespace LLama.Native | |||
| @@ -9,11 +11,29 @@ namespace LLama.Native | |||
| public class SafeLLamaContextHandle | |||
| : SafeLLamaHandleBase | |||
| { | |||
| #region properties and fields | |||
| /// <summary> | |||
| /// Total number of tokens in vocabulary of this model | |||
| /// </summary> | |||
| public int VocabCount => ThrowIfDisposed().VocabCount; | |||
| /// <summary> | |||
| /// Total number of tokens in the context | |||
| /// </summary> | |||
| public int ContextSize => ThrowIfDisposed().ContextSize; | |||
| /// <summary> | |||
| /// Dimension of embedding vectors | |||
| /// </summary> | |||
| public int EmbeddingCount => ThrowIfDisposed().EmbeddingCount; | |||
| /// <summary> | |||
| /// This field guarantees that a reference to the model is held for as long as this handle is held | |||
| /// </summary> | |||
| private SafeLlamaModelHandle? _model; | |||
| #endregion | |||
| #region construction/destruction | |||
| /// <summary> | |||
| /// Create a new SafeLLamaContextHandle | |||
| /// </summary> | |||
| @@ -42,6 +62,16 @@ namespace LLama.Native | |||
| return true; | |||
| } | |||
| private SafeLlamaModelHandle ThrowIfDisposed() | |||
| { | |||
| if (IsClosed) | |||
| throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed"); | |||
| if (_model == null || _model.IsClosed) | |||
| throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - `SafeLlamaModelHandle` has been disposed"); | |||
| return _model; | |||
| } | |||
| /// <summary> | |||
| /// Create a new llama_state for the given model | |||
| /// </summary> | |||
| @@ -57,5 +87,103 @@ namespace LLama.Native | |||
| return new(ctx_ptr, model); | |||
| } | |||
| #endregion | |||
| /// <summary> | |||
| /// Convert the given text into tokens | |||
| /// </summary> | |||
| /// <param name="text">The text to tokenize</param> | |||
| /// <param name="add_bos">Whether the "BOS" token should be added</param> | |||
| /// <param name="encoding">Encoding to use for the text</param> | |||
| /// <returns></returns> | |||
| /// <exception cref="RuntimeError"></exception> | |||
| public int[] Tokenize(string text, bool add_bos, Encoding encoding) | |||
| { | |||
| ThrowIfDisposed(); | |||
| // Calculate number of bytes in string, this is a pessimistic estimate of token count. It can't | |||
| // possibly be more than this. | |||
| var count = encoding.GetByteCount(text) + (add_bos ? 1 : 0); | |||
| // "Rent" an array to write results into (avoiding an allocation of a large array) | |||
| var temporaryArray = ArrayPool<int>.Shared.Rent(count); | |||
| try | |||
| { | |||
| // Do the actual conversion | |||
| var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos); | |||
| if (n < 0) | |||
| { | |||
| throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " + | |||
| "specify the encoding."); | |||
| } | |||
| // Copy the results from the rented into an array which is exactly the right size | |||
| var result = new int[n]; | |||
| Array.ConstrainedCopy(temporaryArray, 0, result, 0, n); | |||
| return result; | |||
| } | |||
| finally | |||
| { | |||
| ArrayPool<int>.Shared.Return(temporaryArray); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Token logits obtained from the last call to llama_eval() | |||
| /// The logits for the last token are stored in the last row | |||
| /// Can be mutated in order to change the probabilities of the next token.<br /> | |||
| /// Rows: n_tokens<br /> | |||
| /// Cols: n_vocab | |||
| /// </summary> | |||
| /// <param name="ctx"></param> | |||
| /// <returns></returns> | |||
| public Span<float> GetLogits() | |||
| { | |||
| var model = ThrowIfDisposed(); | |||
| unsafe | |||
| { | |||
| var logits = NativeApi.llama_get_logits(this); | |||
| return new Span<float>(logits, model.VocabCount); | |||
| } | |||
| } | |||
| /// <summary> | |||
| /// Convert a token into a string | |||
| /// </summary> | |||
| /// <param name="token"></param> | |||
| /// <param name="encoding"></param> | |||
| /// <returns></returns> | |||
| public string TokenToString(int token, Encoding encoding) | |||
| { | |||
| return ThrowIfDisposed().TokenToString(token, encoding); | |||
| } | |||
| /// <summary> | |||
| /// Convert a token into a span of bytes that could be decoded into a string | |||
| /// </summary> | |||
| /// <param name="token"></param> | |||
| /// <returns></returns> | |||
| public ReadOnlySpan<byte> TokenToSpan(int token) | |||
| { | |||
| return ThrowIfDisposed().TokenToSpan(token); | |||
| } | |||
| /// <summary> | |||
| /// Run the llama inference to obtain the logits and probabilities for the next token. | |||
| /// </summary> | |||
| /// <param name="tokens">The provided batch of new tokens to process</param> | |||
| /// <param name="n_past">the number of tokens to use from previous eval calls</param> | |||
| /// <param name="n_threads"></param> | |||
| /// <returns>Returns true on success</returns> | |||
| public bool Eval(Memory<int> tokens, int n_past, int n_threads) | |||
| { | |||
| using var pin = tokens.Pin(); | |||
| unsafe | |||
| { | |||
| return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -13,17 +13,17 @@ namespace LLama.Native | |||
| /// <summary> | |||
| /// Total number of tokens in vocabulary of this model | |||
| /// </summary> | |||
| public int VocabCount { get; set; } | |||
| public int VocabCount { get; } | |||
| /// <summary> | |||
| /// Total number of tokens in the context | |||
| /// </summary> | |||
| public int ContextSize { get; set; } | |||
| public int ContextSize { get; } | |||
| /// <summary> | |||
| /// Dimension of embedding vectors | |||
| /// </summary> | |||
| public int EmbeddingCount { get; set; } | |||
| public int EmbeddingCount { get; } | |||
| internal SafeLlamaModelHandle(IntPtr handle) | |||
| : base(handle) | |||
| @@ -2,10 +2,8 @@ | |||
| using LLama.Native; | |||
| using System; | |||
| using System.Collections.Generic; | |||
| using System.Linq; | |||
| using System.Runtime.InteropServices; | |||
| using System.Text; | |||
| using LLama.Exceptions; | |||
| using LLama.Extensions; | |||
| namespace LLama | |||
| @@ -27,41 +25,36 @@ namespace LLama | |||
| } | |||
| } | |||
| [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")] | |||
| public static IEnumerable<llama_token> Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding) | |||
| { | |||
| var cnt = encoding.GetByteCount(text); | |||
| llama_token[] res = new llama_token[cnt + (add_bos ? 1 : 0)]; | |||
| int n = NativeApi.llama_tokenize(ctx, text, encoding, res, res.Length, add_bos); | |||
| if (n < 0) | |||
| { | |||
| throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " + | |||
| "specify the encoding."); | |||
| } | |||
| return res.Take(n); | |||
| return ctx.Tokenize(text, add_bos, encoding); | |||
| } | |||
| public static unsafe Span<float> GetLogits(SafeLLamaContextHandle ctx, int length) | |||
| [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")] | |||
| public static Span<float> GetLogits(SafeLLamaContextHandle ctx, int length) | |||
| { | |||
| var logits = NativeApi.llama_get_logits(ctx); | |||
| return new Span<float>(logits, length); | |||
| if (length != ctx.VocabCount) | |||
| throw new ArgumentException("length must be the VocabSize"); | |||
| return ctx.GetLogits(); | |||
| } | |||
| public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads) | |||
| [Obsolete("Use SafeLLamaContextHandle Eval method instead")] | |||
| public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads) | |||
| { | |||
| int result; | |||
| fixed(llama_token* p = tokens) | |||
| { | |||
| result = NativeApi.llama_eval_with_pointer(ctx, p + startIndex, n_tokens, n_past, n_threads); | |||
| } | |||
| return result; | |||
| var slice = tokens.AsMemory().Slice(startIndex, n_tokens); | |||
| return ctx.Eval(slice, n_past, n_threads) ? 0 : 1; | |||
| } | |||
| [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")] | |||
| public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding) | |||
| { | |||
| return PtrToString(NativeApi.llama_token_to_str(ctx, token), encoding); | |||
| return ctx.TokenToString(token, encoding); | |||
| } | |||
| public static unsafe string PtrToString(IntPtr ptr, Encoding encoding) | |||
| [Obsolete("No longer used internally by LlamaSharp")] | |||
| public static string PtrToString(IntPtr ptr, Encoding encoding) | |||
| { | |||
| #if NET6_0_OR_GREATER | |||
| if(encoding == Encoding.UTF8) | |||
| @@ -77,21 +70,24 @@ namespace LLama | |||
| return Marshal.PtrToStringAuto(ptr); | |||
| } | |||
| #else | |||
| byte* tp = (byte*)ptr.ToPointer(); | |||
| List<byte> bytes = new(); | |||
| while (true) | |||
| unsafe | |||
| { | |||
| byte c = *tp++; | |||
| if (c == '\0') | |||
| { | |||
| break; | |||
| } | |||
| else | |||
| byte* tp = (byte*)ptr.ToPointer(); | |||
| List<byte> bytes = new(); | |||
| while (true) | |||
| { | |||
| bytes.Add(c); | |||
| byte c = *tp++; | |||
| if (c == '\0') | |||
| { | |||
| break; | |||
| } | |||
| else | |||
| { | |||
| bytes.Add(c); | |||
| } | |||
| } | |||
| return encoding.GetString(bytes.ToArray()); | |||
| } | |||
| return encoding.GetString(bytes.ToArray()); | |||
| #endif | |||
| } | |||
| } | |||