Browse Source

feat: support loading and saving state.

tags/v0.2.3
Yaohui Liu 2 years ago
parent
commit
19979f664a
No known key found for this signature in database GPG Key ID: E86D01E1809BD23E
6 changed files with 42 additions and 13 deletions
  1. +22
    -8
      LLama/LLamaModel.cs
  2. +2
    -2
      LLama/LLamaSharp.csproj
  3. +15
    -0
      LLama/Native/GgmlInitParams.cs
  4. +1
    -1
      LLama/Native/NativeApi.Quantize.cs
  5. +1
    -1
      LLama/Native/NativeApi.Sampling.cs
  6. +1
    -1
      LLama/Native/NativeApi.cs

+ 22
- 8
LLama/LLamaModel.cs View File

@@ -36,9 +36,6 @@ namespace LLama
int _n_session_consumed;
List<llama_token> _embed;

// params related to chat API only
bool _first_time_chat = true;

public string Name { get; set; }
public SafeLLamaContextHandle NativeHandle => _ctx;

@@ -53,11 +50,12 @@ namespace LLama
bool memory_f16 = true, bool random_prompt = false, bool use_color = false, bool interactive = false,
bool embedding = false, bool interactive_first = false, bool prompt_cache_all = false, bool instruct = false, bool penalize_nl = true,
bool perplexity = false, bool use_mmap = true, bool use_mlock = false, bool mem_test = false,
bool verbose_prompt = false) : this(new LLamaParams(seed, n_threads, n_predict, n_parts, n_ctx, n_batch,
bool verbose_prompt = false, string encoding = "UTF-8") : this(new LLamaParams(seed, n_threads, n_predict, n_parts, n_ctx, n_batch,
n_keep, n_gpu_layers, logit_bias, top_k, top_p, tfs_z, typical_p, temp, repeat_penalty, repeat_last_n, frequency_penalty,
presence_penalty, mirostat, mirostat_tau, mirostat_eta, model_path, prompt, path_session, input_prefix,
input_suffix, antiprompt, lora_adapter, lora_base, memory_f16, random_prompt, use_color, interactive, embedding,
interactive_first, prompt_cache_all, instruct, penalize_nl, perplexity, use_mmap, use_mlock, mem_test, verbose_prompt), model_name, echo_input, verbose)
interactive_first, prompt_cache_all, instruct, penalize_nl, perplexity, use_mmap, use_mlock, mem_test, verbose_prompt),
model_name, echo_input, verbose, encoding)
{
}
@@ -293,6 +291,25 @@ namespace LLama
return Call(text, encoding);
}

/// <summary>
/// Persists the full native llama context state to a file so it can be
/// restored later with <c>LoadState</c>.
/// </summary>
/// <param name="filename">Destination path; the file is created or overwritten.</param>
public void SaveState(string filename)
{
    // Ask the native side how large the serialized state is, copy it into
    // a managed buffer, then dump the buffer to disk in one write.
    var buffer = new byte[NativeApi.llama_get_state_size(_ctx)];
    NativeApi.llama_copy_state_data(_ctx, buffer);
    File.WriteAllBytes(filename, buffer);
}

/// <summary>
/// Restores a native llama context state previously written by <c>SaveState</c>.
/// </summary>
/// <param name="filename">Path of the state file to load.</param>
/// <exception cref="RuntimeError">
/// Thrown when the file size does not match the state size the current
/// context expects (e.g. the file was saved with different model parameters).
/// </exception>
public void LoadState(string filename)
{
    byte[] stateMemory = File.ReadAllBytes(filename);
    // llama_get_state_size returns a native size_t (ulong). Compare in ulong:
    // the previous (int) cast truncated sizes >= 2 GiB and could falsely
    // pass or fail validation after overflow.
    ulong expectedSize = NativeApi.llama_get_state_size(_ctx);
    if ((ulong)stateMemory.Length != expectedSize)
    {
        throw new RuntimeError("Failed to validate state size.");
    }
    NativeApi.llama_set_state_data(_ctx, stateMemory);
}

public IEnumerable<string> Call(string text, string encoding = "UTF-8")
{
_is_antiprompt = false;
@@ -507,9 +524,6 @@ namespace LLama
}
else
{
// Assuming that the necessary variables have been defined and initialized,
// the C# equivalent code could be:

while (_embed_inp.Count > _n_consumed)
{
_embed.Add(_embed_inp[_n_consumed]);


+ 2
- 2
LLama/LLamaSharp.csproj View File

@@ -8,7 +8,7 @@
<Platforms>AnyCPU;x64</Platforms>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>

<Version>0.2.3</Version>
<Version>0.2.4</Version>
<Authors>Yaohui Liu, Haiping Chen</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
@@ -21,7 +21,7 @@
The .NET binding of LLama.cpp, providing APIs to run the model and deploy it on Web.
</Description>
<PackageReleaseNotes>
LLama 0.2.3 mainly fixed some BUGs of model inference.
LLama 0.2.4 mainly supports loading and saving session state.
</PackageReleaseNotes>
<PackageLicenseExpression>MIT</PackageLicenseExpression>
<PackageOutputPath>packages</PackageOutputPath>


+ 15
- 0
LLama/Native/GgmlInitParams.cs View File

@@ -0,0 +1,15 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace LLama.Native
{
    /// <summary>
    /// Managed mirror of the native <c>ggml_init_params</c> struct passed to
    /// ggml initialization (presumably matches <c>ggml.h</c> — TODO confirm
    /// field order and sizes against the native header before changing).
    /// Field order is significant for interop marshaling; do not reorder.
    /// </summary>
    internal struct GgmlInitParams
    {
        // Size in bytes of the memory pool ggml may use (native size_t).
        public ulong mem_size;
        // Optional caller-provided buffer; IntPtr.Zero lets ggml allocate.
        public IntPtr mem_buffer;
        // Marshaled as a single byte to match the native C bool.
        [MarshalAs(UnmanagedType.I1)]
        public bool no_alloc;
    }
}

+ 1
- 1
LLama/Native/NativeApi.Quantize.cs View File

@@ -5,7 +5,7 @@ using System.Text;

namespace LLama.Native
{
internal partial class NativeApi
public partial class NativeApi
{
/// <summary>
/// Returns 0 on success


+ 1
- 1
LLama/Native/NativeApi.Sampling.cs View File

@@ -6,7 +6,7 @@ using System.Text;
namespace LLama.Native
{
using llama_token = Int32;
internal unsafe partial class NativeApi
public unsafe partial class NativeApi
{
/// <summary>
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.


+ 1
- 1
LLama/Native/NativeApi.cs View File

@@ -8,7 +8,7 @@ using LLama.Exceptions;
namespace LLama.Native
{
using llama_token = Int32;
internal unsafe partial class NativeApi
public unsafe partial class NativeApi
{
static NativeApi()
{


Loading…
Cancel
Save