diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 2a74bb29..251573fc 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -164,10 +164,8 @@ public class ChatSession /// SessionState object representing session state in-memory public SessionState GetSessionState() { - return new() + return new SessionState(Executor.Context.GetState(), ((StatefulExecutorBase)Executor).GetStateData()) { - ExecutorState = JsonSerializer.Serialize(((StatefulExecutorBase)Executor).GetStateData()), - ContextState = Executor.Context.GetState(), InputTransformPipeline = InputTransformPipeline, OutputTransform = OutputTransform, HistoryTransform = HistoryTransform, @@ -185,34 +183,17 @@ public class ChatSession { if (Executor is StatefulExecutorBase statefulExecutor) { - if (state.ExecutorState is null) - { - statefulExecutor.ResetState(); - } - else - { - statefulExecutor.LoadState( - JsonSerializer.Deserialize( - state.ExecutorState, statefulExecutor.GetStateData().GetType() - ) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state)) - ); - } - } - else - { - if (state.ExecutorState is not null) - { - throw new ArgumentException("Executor does not support state", nameof(state)); - } - } - if (state.ContextState is null) - { - Executor.Context.ResetState(); + statefulExecutor.LoadState( + JsonSerializer.Deserialize( + state.ExecutorState, statefulExecutor.GetStateData().GetType() + ) as ExecutorBaseState ?? throw new ArgumentException("Executor state is invalid", nameof(state)) + ); } else { - Executor.Context.LoadState(state.ContextState); + throw new ArgumentException("Executor must be a StatefulExecutorBase to support loading of session state", nameof(state)); } + Executor.Context.LoadState(state.ContextState); History = ChatHistory.FromJson(state.History) ?? new(); } @@ -612,12 +593,12 @@ public record SessionState /// /// Saved executor state for the session in JSON format. /// - public string? ExecutorState { get; init; } + public string ExecutorState { get; init; } /// /// Saved context state (KV cache) for the session. /// - public State? ContextState { get; init; } + public State ContextState { get; init; } /// /// The input transform pipeline used in this session. @@ -638,4 +619,15 @@ public record SessionState /// The JSON representation of the chat history for this session. /// public string History { get; init; } = new ChatHistory().ToJson(); + + /// + /// Create a new session state. + /// + /// + /// + public SessionState(State contextState, ExecutorBaseState executorState) + { + ContextState = contextState; + ExecutorState = JsonSerializer.Serialize(executorState); + } } \ No newline at end of file diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index dc0508e5..d8b418c3 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -24,7 +24,6 @@ namespace LLama : IDisposable { private readonly ILogger? _logger; - private readonly State _emptyState; /// /// Total number of tokens in vocabulary of this model @@ -76,7 +75,6 @@ namespace LLama @params.ToLlamaContextParams(out var lparams); NativeHandle = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); - _emptyState = GetState(); } /// @@ -216,15 +214,6 @@ namespace LLama } } - /// - /// Reset the context to the empty state. - /// - /// - public void ResetState() - { - NativeApi.llama_kv_cache_clear(NativeHandle); - } - /// /// Sample a single token from this context, using the given sampling pipeline /// diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs index b9a0b412..7bdb8d2b 100644 --- a/LLama/LLamaExecutorBase.cs +++ b/LLama/LLamaExecutorBase.cs @@ -250,14 +250,6 @@ namespace LLama /// public abstract ExecutorBaseState GetStateData(); - - /// - /// Resets the executor to its initial state. - /// Note: Does not affect the context and KV cache. - /// - /// - public abstract void ResetState(); - /// /// Load the state from data. /// diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 13e8f067..9476976e 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -110,22 +110,6 @@ namespace LLama } } - /// - public override void ResetState() - { - _n_session_consumed = 0; - _embed_inps = new List(); - _is_prompt_run = true; - _consumedTokensCount = 0; - _embeds = new List(); - _last_n_tokens = new FixedSizeQueue((int) Context.ContextSize); - _n_matching_session_tokens = 0; - _pastTokensCount = 0; - _pathSession = null; - _session_tokens = new List(); - MirostatMu = 0; - } - /// protected override Task GetLoopCondition(InferStateArgs args) { diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index 9c6ae954..79f1b8cc 100644 --- a/LLama/LLamaInteractExecutor.cs +++ b/LLama/LLamaInteractExecutor.cs @@ -93,22 +93,6 @@ namespace LLama } } - /// - public override void ResetState() - { - _n_session_consumed = 0; - _embed_inps = new List(); - _is_prompt_run = true; - _consumedTokensCount = 0; - _embeds = new List(); - _last_n_tokens = new FixedSizeQueue((int)Context.ContextSize); - _n_matching_session_tokens = 0; - _pastTokensCount = 0; - _pathSession = null; - _session_tokens = new List(); - MirostatMu = 0; - } - /// /// Define whether to continue the loop to generate responses. ///