Browse Source

Remove state-resetting ops and make SessionState.ExecutorState and SessionState.ContextState non-nullable

tags/0.11.0
eublefar 2 years ago
parent
commit
e05d5d4e14
5 changed files with 21 additions and 80 deletions
  1. +21
    -29
      LLama/ChatSession.cs
  2. +0
    -11
      LLama/LLamaContext.cs
  3. +0
    -8
      LLama/LLamaExecutorBase.cs
  4. +0
    -16
      LLama/LLamaInstructExecutor.cs
  5. +0
    -16
      LLama/LLamaInteractExecutor.cs

+ 21
- 29
LLama/ChatSession.cs View File

@@ -164,10 +164,8 @@ public class ChatSession
/// <returns>SessionState object representing session state in-memory</returns>
public SessionState GetSessionState()
{
return new()
return new SessionState(Executor.Context.GetState(), ((StatefulExecutorBase)Executor).GetStateData())
{
ExecutorState = JsonSerializer.Serialize(((StatefulExecutorBase)Executor).GetStateData()),
ContextState = Executor.Context.GetState(),
InputTransformPipeline = InputTransformPipeline,
OutputTransform = OutputTransform,
HistoryTransform = HistoryTransform,
@@ -185,34 +183,17 @@ public class ChatSession
{
if (Executor is StatefulExecutorBase statefulExecutor)
{
if (state.ExecutorState is null)
{
statefulExecutor.ResetState();
}
else
{
statefulExecutor.LoadState(
JsonSerializer.Deserialize(
state.ExecutorState, statefulExecutor.GetStateData().GetType()
) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state))
);
}
}
else
{
if (state.ExecutorState is not null)
{
throw new ArgumentException("Executor does not support state", nameof(state));
}
}
if (state.ContextState is null)
{
Executor.Context.ResetState();
statefulExecutor.LoadState(
JsonSerializer.Deserialize(
state.ExecutorState, statefulExecutor.GetStateData().GetType()
) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state))
);
}
else
{
Executor.Context.LoadState(state.ContextState);
throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state));
}
Executor.Context.LoadState(state.ContextState);
History = ChatHistory.FromJson(state.History) ?? new();
}

@@ -612,12 +593,12 @@ public record SessionState
/// <summary>
/// Saved executor state for the session in JSON format.
/// </summary>
public string? ExecutorState { get; init; }
public string ExecutorState { get; init; }

/// <summary>
/// Saved context state (KV cache) for the session.
/// </summary>
public State? ContextState { get; init; }
public State ContextState { get; init; }

/// <summary>
/// The input transform pipeline used in this session.
@@ -638,4 +619,15 @@ public record SessionState
/// The JSON representation of the chat history for this session.
/// </summary>
public string History { get; init; } = new ChatHistory().ToJson();

/// <summary>
/// Create a new session state from a saved context state and an executor state.
/// </summary>
/// <param name="contextState">Saved context state; stored unchanged in <see cref="ContextState"/>.</param>
/// <param name="executorState">Executor state; serialized to JSON and stored in <see cref="ExecutorState"/>.</param>
/// <exception cref="ArgumentNullException">
/// Thrown when <paramref name="contextState"/> or <paramref name="executorState"/> is null.
/// </exception>
public SessionState(State contextState, ExecutorBaseState executorState)
{
    if (executorState is null)
    {
        throw new ArgumentNullException(nameof(executorState));
    }

    ContextState = contextState ?? throw new ArgumentNullException(nameof(contextState));

    // Serialize with the runtime type. The generic overload would bind to the declared
    // type (ExecutorBaseState) and drop members declared on the concrete state class,
    // while the session-loading code deserializes using the executor's concrete state type.
    ExecutorState = JsonSerializer.Serialize(executorState, executorState.GetType());
}
}

+ 0
- 11
LLama/LLamaContext.cs View File

@@ -24,7 +24,6 @@ namespace LLama
: IDisposable
{
private readonly ILogger? _logger;
private readonly State _emptyState;

/// <summary>
/// Total number of tokens in vocabulary of this model
@@ -76,7 +75,6 @@ namespace LLama

@params.ToLlamaContextParams(out var lparams);
NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
_emptyState = GetState();
}

/// <summary>
@@ -216,15 +214,6 @@ namespace LLama
}
}

/// <summary>
/// Reset this context by clearing its KV cache through the native
/// <c>llama_kv_cache_clear</c> call on <see cref="NativeHandle"/>.
/// </summary>
public void ResetState() => NativeApi.llama_kv_cache_clear(NativeHandle);

/// <summary>
/// Sample a single token from this context, using the given sampling pipeline
/// </summary>


+ 0
- 8
LLama/LLamaExecutorBase.cs View File

@@ -250,14 +250,6 @@ namespace LLama
/// <returns></returns>
public abstract ExecutorBaseState GetStateData();

/// <summary>
/// Resets the executor to its initial state.
/// Note: Does not affect the context and KV cache.
/// </summary>
/// <remarks>
/// Abstract: every concrete stateful executor must provide its own implementation.
/// </remarks>
public abstract void ResetState();

/// <summary>
/// Load the state from data.
/// </summary>


+ 0
- 16
LLama/LLamaInstructExecutor.cs View File

@@ -110,22 +110,6 @@ namespace LLama
}
}

/// <inheritdoc />
public override void ResetState()
{
    // All token/session counters return to zero.
    _pastTokensCount = 0;
    _consumedTokensCount = 0;
    _n_session_consumed = 0;
    _n_matching_session_tokens = 0;

    // Replace every token buffer with a fresh, empty instance.
    _embeds = new List<LLamaToken>();
    _embed_inps = new List<LLamaToken>();
    _session_tokens = new List<LLamaToken>();

    // The _last_n_tokens queue is rebuilt using the context size (only read, not modified).
    _last_n_tokens = new FixedSizeQueue<LLamaToken>((int)Context.ContextSize);

    // Flag the next run as a prompt run and forget any session path.
    _is_prompt_run = true;
    _pathSession = null;

    MirostatMu = 0;
}

/// <inheritdoc />
protected override Task<bool> GetLoopCondition(InferStateArgs args)
{


+ 0
- 16
LLama/LLamaInteractExecutor.cs View File

@@ -93,22 +93,6 @@ namespace LLama
}
}

/// <inheritdoc />
public override void ResetState()
{
    // Forget any session path and mark the next run as a prompt run.
    _pathSession = null;
    _is_prompt_run = true;

    // Zero out the token/session counters.
    _n_matching_session_tokens = 0;
    _n_session_consumed = 0;
    _consumedTokensCount = 0;
    _pastTokensCount = 0;

    // Start over with empty token buffers.
    _session_tokens = new List<LLamaToken>();
    _embed_inps = new List<LLamaToken>();
    _embeds = new List<LLamaToken>();

    // Rebuild the _last_n_tokens queue from the context size (the context itself is only read here).
    _last_n_tokens = new FixedSizeQueue<LLamaToken>((int)Context.ContextSize);

    MirostatMu = 0;
}

/// <summary>
/// Define whether to continue the loop to generate responses.
/// </summary>


Loading…
Cancel
Save