Browse Source

Minor cleanup on the BatchedDecoding example

tags/v0.7.0^2
Martin Evans 2 years ago
parent
commit
aae63a5b92
1 changed file with 10 additions and 10 deletions
  1. +10
    -10
      LLama.Examples/NewVersion/BatchedDecoding.cs

+ 10
- 10
LLama.Examples/NewVersion/BatchedDecoding.cs View File

@@ -15,9 +15,9 @@ public class BatchedDecoding
private const int n_parallel = 8;
private const int n_len = 32;

- private const int top_k = 40;
- private const float top_p = 0.9f;
- private const float temp = 0.4f;
+ private const int top_k = 80;
+ private const float top_p = 0.8f;
+ private const float temp = 0.5f;

public static async Task Run()
{
@@ -57,7 +57,7 @@ public class BatchedDecoding

// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
- llama_batch_add(batch, prompt_tokens[i], i, new() { (LLamaSeqId)0 }, false);
+ llama_batch_add(batch, prompt_tokens[i], i, new List<LLamaSeqId> { (LLamaSeqId)0 }, false);
Debug.Assert(batch.NativeBatch.n_tokens == (int)prompt_tokens.Length);

// llama_decode will output logits only for the last token of the prompt
@@ -66,7 +66,7 @@ public class BatchedDecoding
batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
}

- if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
+ if (context.NativeHandle.Decode(batch) != 0)
{
await Console.Error.WriteLineAsync("llama_decode failed");
return;
@@ -91,8 +91,8 @@ public class BatchedDecoding
for (var i = 0; i < n_parallel; i++)
i_batch.Add(batch.NativeBatch.n_tokens - 1);

- int n_cur = batch.NativeBatch.n_tokens;
- int n_decode = 0;
+ var n_cur = batch.NativeBatch.n_tokens;
+ var n_decode = 0;

var streams = new List<int>[n_parallel];
for (var i = 0; i < n_parallel; i++)
@@ -137,7 +137,7 @@ public class BatchedDecoding
i_batch[i] = batch.NativeBatch.n_tokens;

// push this new token for next evaluation
- llama_batch_add(batch, new_token_id, n_cur, new() { (LLamaSeqId)i }, true);
+ llama_batch_add(batch, new_token_id, n_cur, new List<LLamaSeqId> { (LLamaSeqId)i }, true);

n_decode++;
}
@@ -151,7 +151,7 @@ public class BatchedDecoding
n_cur++;

// evaluate the current batch with the transformer model
- if (NativeApi.llama_decode(context.NativeHandle, batch.NativeBatch) != 0)
+ if (context.NativeHandle.Decode(batch) != 0)
{
await Console.Error.WriteLineAsync("failed to eval");
return;
@@ -179,7 +179,7 @@ public class BatchedDecoding
/// <summary>
/// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
/// </summary>
- private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, List<LLamaSeqId> sequences, bool logits)
+ private static void llama_batch_add(LLamaBatchSafeHandle batchHandle, int token, LLamaPos pos, IReadOnlyList<LLamaSeqId> sequences, bool logits)
{
unsafe
{


Loading…
Cancel
Save