diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index 2a74bb29..251573fc 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -164,10 +164,8 @@ public class ChatSession
/// SessionState object representing session state in-memory
public SessionState GetSessionState()
{
- return new()
+ return new SessionState(Executor.Context.GetState(), ((StatefulExecutorBase)Executor).GetStateData())
{
- ExecutorState = JsonSerializer.Serialize(((StatefulExecutorBase)Executor).GetStateData()),
- ContextState = Executor.Context.GetState(),
InputTransformPipeline = InputTransformPipeline,
OutputTransform = OutputTransform,
HistoryTransform = HistoryTransform,
@@ -185,34 +183,17 @@ public class ChatSession
{
if (Executor is StatefulExecutorBase statefulExecutor)
{
- if (state.ExecutorState is null)
- {
- statefulExecutor.ResetState();
- }
- else
- {
- statefulExecutor.LoadState(
- JsonSerializer.Deserialize(
- state.ExecutorState, statefulExecutor.GetStateData().GetType()
- ) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state))
- );
- }
- }
- else
- {
- if (state.ExecutorState is not null)
- {
- throw new ArgumentException("Executor does not support state", nameof(state));
- }
- }
- if (state.ContextState is null)
- {
- Executor.Context.ResetState();
+ statefulExecutor.LoadState(
+ JsonSerializer.Deserialize(
+ state.ExecutorState, statefulExecutor.GetStateData().GetType()
+ ) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state))
+ );
}
else
{
- Executor.Context.LoadState(state.ContextState);
+ throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state));
}
+ Executor.Context.LoadState(state.ContextState);
History = ChatHistory.FromJson(state.History) ?? new();
}
@@ -612,12 +593,12 @@ public record SessionState
///
/// Saved executor state for the session in JSON format.
///
- public string? ExecutorState { get; init; }
+ public string ExecutorState { get; init; }
///
/// Saved context state (KV cache) for the session.
///
- public State? ContextState { get; init; }
+ public State ContextState { get; init; }
///
/// The input transform pipeline used in this session.
@@ -638,4 +619,15 @@ public record SessionState
/// The JSON representation of the chat history for this session.
///
public string History { get; init; } = new ChatHistory().ToJson();
+
+ ///
+ /// Create a new session state.
+ ///
+ ///
+ ///
+ public SessionState(State contextState, ExecutorBaseState executorState)
+ {
+ ContextState = contextState;
+ ExecutorState = JsonSerializer.Serialize(executorState);
+ }
}
\ No newline at end of file
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index dc0508e5..d8b418c3 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -24,7 +24,6 @@ namespace LLama
: IDisposable
{
private readonly ILogger? _logger;
- private readonly State _emptyState;
///
/// Total number of tokens in vocabulary of this model
@@ -76,7 +75,6 @@ namespace LLama
@params.ToLlamaContextParams(out var lparams);
NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams);
- _emptyState = GetState();
}
///
@@ -216,15 +214,6 @@ namespace LLama
}
}
- ///
- /// Reset the context to the empty state.
- ///
- ///
- public void ResetState()
- {
- NativeApi.llama_kv_cache_clear(NativeHandle);
- }
-
///
/// Sample a single token from this context, using the given sampling pipeline
///
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index b9a0b412..7bdb8d2b 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -250,14 +250,6 @@ namespace LLama
///
public abstract ExecutorBaseState GetStateData();
-
- ///
- /// Resets the executor to its initial state.
- /// Note: Does not affect the context and KV cache.
- ///
- ///
- public abstract void ResetState();
-
///
/// Load the state from data.
///
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 13e8f067..9476976e 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -110,22 +110,6 @@ namespace LLama
}
}
- ///
- public override void ResetState()
- {
- _n_session_consumed = 0;
- _embed_inps = new List();
- _is_prompt_run = true;
- _consumedTokensCount = 0;
- _embeds = new List();
- _last_n_tokens = new FixedSizeQueue((int) Context.ContextSize);
- _n_matching_session_tokens = 0;
- _pastTokensCount = 0;
- _pathSession = null;
- _session_tokens = new List();
- MirostatMu = 0;
- }
-
///
protected override Task GetLoopCondition(InferStateArgs args)
{
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 9c6ae954..79f1b8cc 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -93,22 +93,6 @@ namespace LLama
}
}
- ///
- public override void ResetState()
- {
- _n_session_consumed = 0;
- _embed_inps = new List();
- _is_prompt_run = true;
- _consumedTokensCount = 0;
- _embeds = new List();
- _last_n_tokens = new FixedSizeQueue((int)Context.ContextSize);
- _n_matching_session_tokens = 0;
- _pastTokensCount = 0;
- _pathSession = null;
- _session_tokens = new List();
- MirostatMu = 0;
- }
-
///
/// Define whether to continue the loop to generate responses.
///