diff --git a/LLama/LLamaModel.cs b/LLama/LLamaModel.cs index 244d5f57..526ffcb2 100644 --- a/LLama/LLamaModel.cs +++ b/LLama/LLamaModel.cs @@ -335,7 +335,7 @@ namespace LLama n_eval = Params.BatchSize; } - if(Utils.Eval(_ctx, tokens, i, n_eval, pastTokensCount, Params.Threads) != 0) + if (!_ctx.Eval(tokens.AsMemory(i, n_eval), pastTokensCount, Params.Threads)) { _logger?.Log(nameof(LLamaModel), "Failed to eval.", ILLamaLogger.LogLevel.Error); throw new RuntimeError("Failed to eval."); diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index a3b41df4..3ae9ed54 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -31,7 +31,7 @@ namespace LLama _model = model; var tokens = model.Tokenize(" ", true).ToArray(); - Utils.Eval(_model.NativeHandle, tokens, 0, tokens.Length, 0, _model.Params.Threads); + _model.NativeHandle.Eval(tokens.AsMemory(0, tokens.Length), 0, _model.Params.Threads); _originalState = model.GetState(); } @@ -52,7 +52,7 @@ namespace LLama List tokens = _model.Tokenize(text, true).ToList(); int n_prompt_tokens = tokens.Count; - Utils.Eval(_model.NativeHandle, tokens.ToArray(), 0, n_prompt_tokens, n_past, _model.Params.Threads); + _model.NativeHandle.Eval(tokens.ToArray().AsMemory(0, n_prompt_tokens), n_past, _model.Params.Threads); lastTokens.AddRange(tokens); n_past += n_prompt_tokens; diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 1239da27..edfb4152 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -207,6 +207,17 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads); + /// + /// Run the llama inference to obtain the logits and probabilities for the next token. + /// tokens + n_tokens is the provided batch of new tokens to process + /// n_past is the number of tokens to use from previous eval calls + /// + /// + /// + /// + /// + /// + /// Returns 0 on success [DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)] public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads); diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 8940262b..fa54f73e 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -169,5 +169,21 @@ namespace LLama.Native { return ThrowIfDisposed().TokenToSpan(token); } + + /// + /// Run the llama inference to obtain the logits and probabilities for the next token. + /// + /// The provided batch of new tokens to process + /// the number of tokens to use from previous eval calls + /// + /// Returns true on success + public bool Eval(Memory tokens, int n_past, int n_threads) + { + using var pin = tokens.Pin(); + unsafe + { + return NativeApi.llama_eval_with_pointer(this, (int*)pin.Pointer, tokens.Length, n_past, n_threads) == 0; + } + } } } diff --git a/LLama/Utils.cs b/LLama/Utils.cs index 317d8fdd..98dcbd0a 100644 --- a/LLama/Utils.cs +++ b/LLama/Utils.cs @@ -43,14 +43,11 @@ namespace LLama return ctx.GetLogits(); } - public static unsafe int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads) + [Obsolete("Use SafeLLamaContextHandle Eval method instead")] + public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads) { - int result; - fixed(llama_token* p = tokens) - { - result = NativeApi.llama_eval_with_pointer(ctx, p + startIndex, n_tokens, n_past, n_threads); - } - return result; + var slice = tokens.AsMemory().Slice(startIndex, n_tokens); + return ctx.Eval(slice, n_past, n_threads) ? 0 : 1; } [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")]