| @@ -15,9 +15,9 @@ public class BatchedDecoding | |||
| private const int n_parallel = 8; | |||
| private const int n_len = 32; | |||
| private const int top_k = 40; | |||
| private const float top_p = 0.9f; | |||
| private const float temp = 0.4f; | |||
| private const int top_k = 80; | |||
| private const float top_p = 0.8f; | |||
| private const float temp = 0.5f; | |||
| public static async Task Run() | |||
| { | |||
| @@ -57,7 +57,7 @@ public class BatchedDecoding | |||
| // evaluate the initial prompt | |||
| for (var i = 0; i < prompt_tokens.Length; i++) | |||
| llama_batch_add(batch, prompt_tokens[i], i, new() { (LLamaSeqId)0 }, false); | |||
| llama_batch_add(batch, prompt_tokens[i], i, new List<LLamaSeqId> { (LLamaSeqId)0 }, false); | |||
| Debug.Assert(batch.NativeBatch.n_tokens == (int)prompt_tokens.Length); | |||
| // llama_decode will output logits only for the last token of the prompt | |||
| @@ -66,7 +66,7 @@ public class BatchedDecoding | |||
| batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1; | |||
| } | |||
| if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0) | |||
| if (context.NativeHandle.Decode(batch) != 0) | |||
| { | |||
| await Console.Error.WriteLineAsync("llama_decode failed"); | |||
| return; | |||
| @@ -91,8 +91,8 @@ public class BatchedDecoding | |||
| for (var i = 0; i < n_parallel; i++) | |||
| i_batch.Add(batch.NativeBatch.n_tokens - 1); | |||
| int n_cur = batch.NativeBatch.n_tokens; | |||
| int n_decode = 0; | |||
| var n_cur = batch.NativeBatch.n_tokens; | |||
| var n_decode = 0; | |||
| var streams = new List<int>[n_parallel]; | |||
| for (var i = 0; i < n_parallel; i++) | |||
| @@ -137,7 +137,7 @@ public class BatchedDecoding | |||
| i_batch[i] = batch.NativeBatch.n_tokens; | |||
| // push this new token for next evaluation | |||
| llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true); | |||
| llama_batch_add(batch, new_token_id, n_cur, new List<LLamaSeqId> { (LLamaSeqId)i }, true); | |||
| n_decode++; | |||
| } | |||
| @@ -151,7 +151,7 @@ public class BatchedDecoding | |||
| n_cur++; | |||
| // evaluate the current batch with the transformer model | |||
| if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0) | |||
| if (context.NativeHandle.Decode(batch) != 0) | |||
| { | |||
| await Console.Error.WriteLineAsync("failed to eval"); | |||
| return; | |||
| @@ -179,7 +179,7 @@ public class BatchedDecoding | |||
| /// <summary> | |||
| /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 | |||
| /// </summary> | |||
| private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits) | |||
| private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, IReadOnlyList<LLamaSeqId> sequences, bool logits) | |||
| { | |||
| unsafe | |||
| { | |||