diff --git a/LLama.Examples/NewVersion/BatchedDecoding.cs b/LLama.Examples/NewVersion/BatchedDecoding.cs index b48a1ccd..ff4e8f79 100644 --- a/LLama.Examples/NewVersion/BatchedDecoding.cs +++ b/LLama.Examples/NewVersion/BatchedDecoding.cs @@ -15,9 +15,9 @@ public class BatchedDecoding private const int n_parallel = 8; private const int n_len = 32; - private const int top_k = 40; - private const float top_p = 0.9f; - private const float temp = 0.4f; + private const int top_k = 80; + private const float top_p = 0.8f; + private const float temp = 0.5f; public static async Task Run() { @@ -57,7 +57,7 @@ public class BatchedDecoding // evaluate the initial prompt for (var i = 0; i < prompt_tokens.Length; i++) - llama_batch_add(batch, prompt_tokens[i], i, new() { (LLamaSeqId)0 }, false); + llama_batch_add(batch, prompt_tokens[i], i, new List { (LLamaSeqId)0 }, false); Debug.Assert(batch.NativeBatch.n_tokens == (int)prompt_tokens.Length); // llama_decode will output logits only for the last token of the prompt @@ -66,7 +66,7 @@ public class BatchedDecoding batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1; } - if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0) + if (context.NativeHandle.Decode(batch) != 0) { await Console.Error.WriteLineAsync("llama_decode failed"); return; @@ -91,8 +91,8 @@ public class BatchedDecoding for (var i = 0; i < n_parallel; i++) i_batch.Add(batch.NativeBatch.n_tokens - 1); - int n_cur = batch.NativeBatch.n_tokens; - int n_decode = 0; + var n_cur = batch.NativeBatch.n_tokens; + var n_decode = 0; var streams = new List[n_parallel]; for (var i = 0; i < n_parallel; i++) @@ -137,7 +137,7 @@ public class BatchedDecoding i_batch[i] = batch.NativeBatch.n_tokens; // push this new token for next evaluation - llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true); + llama_batch_add(batch, new_token_id, n_cur, new List { (LLamaSeqId)i }, true); n_decode++; } @@ -151,7 +151,7 @@ public class BatchedDecoding n_cur++; // evaluate the current batch with the transformer model - if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0) + if (context.NativeHandle.Decode(batch) != 0) { await Console.Error.WriteLineAsync("failed to eval"); return; @@ -179,7 +179,7 @@ public class BatchedDecoding /// /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 /// - private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, List sequences, bool logits) + private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, IReadOnlyList sequences, bool logits) { unsafe {