diff --git a/LLama.Web/Common/ParameterOptions.cs b/LLama.Web/Common/ParameterOptions.cs
index 3cdd3701..7677f04a 100644
--- a/LLama.Web/Common/ParameterOptions.cs
+++ b/LLama.Web/Common/ParameterOptions.cs
@@ -1,9 +1,99 @@
using LLama.Common;
+using LLama.Abstractions;
namespace LLama.Web.Common
{
- public class ParameterOptions : InferenceParams
- {
+ public class ParameterOptions : IInferenceParams
+ {
public string Name { get; set; }
- }
+
+
+
+ ///
+ /// number of tokens to keep from initial prompt
+ ///
+ public int TokensKeep { get; set; } = 0;
+ ///
+ /// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response
+ /// until it complete.
+ ///
+ public int MaxTokens { get; set; } = -1;
+ ///
+ /// logit bias for specific tokens
+ ///
+ public Dictionary? LogitBias { get; set; } = null;
+
+ ///
+ /// Sequences where the model will stop generating further tokens.
+ ///
+ public IEnumerable AntiPrompts { get; set; } = Array.Empty();
+ ///
+ /// path to file for saving/loading model eval state
+ ///
+ public string PathSession { get; set; } = string.Empty;
+ ///
+ /// string to suffix user inputs with
+ ///
+ public string InputSuffix { get; set; } = string.Empty;
+ ///
+ /// string to prefix user inputs with
+ ///
+ public string InputPrefix { get; set; } = string.Empty;
+ ///
+ /// 0 or lower to use vocab size
+ ///
+ public int TopK { get; set; } = 40;
+ ///
+ /// 1.0 = disabled
+ ///
+ public float TopP { get; set; } = 0.95f;
+ ///
+ /// 1.0 = disabled
+ ///
+ public float TfsZ { get; set; } = 1.0f;
+ ///
+ /// 1.0 = disabled
+ ///
+ public float TypicalP { get; set; } = 1.0f;
+ ///
+ /// 1.0 = disabled
+ ///
+ public float Temperature { get; set; } = 0.8f;
+ ///
+ /// 1.0 = disabled
+ ///
+ public float RepeatPenalty { get; set; } = 1.1f;
+ ///
+ /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
+ ///
+ public int RepeatLastTokensCount { get; set; } = 64;
+ ///
+ /// frequency penalty coefficient
+ /// 0.0 = disabled
+ ///
+ public float FrequencyPenalty { get; set; } = .0f;
+ ///
+ /// presence penalty coefficient
+ /// 0.0 = disabled
+ ///
+ public float PresencePenalty { get; set; } = .0f;
+ ///
+ /// Mirostat uses tokens instead of words.
+ /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
+ /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ ///
+ public MirostatType Mirostat { get; set; } = MirostatType.Disable;
+ ///
+ /// target entropy
+ ///
+ public float MirostatTau { get; set; } = 5.0f;
+ ///
+ /// learning rate
+ ///
+ public float MirostatEta { get; set; } = 0.1f;
+ ///
+ /// consider newlines as a repeatable token (penalize_nl)
+ ///
+ public bool PenalizeNL { get; set; } = true;
+ }
}
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
new file mode 100644
index 00000000..73cbbfd2
--- /dev/null
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -0,0 +1,117 @@
+using System.Collections.Generic;
+using LLama.Common;
+
+namespace LLama.Abstractions
+{
+ ///
+ /// The paramters used for inference.
+ ///
+ public interface IInferenceParams
+ {
+ ///
+ /// number of tokens to keep from initial prompt
+ ///
+ public int TokensKeep { get; set; }
+
+ ///
+ /// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response
+ /// until it complete.
+ ///
+ public int MaxTokens { get; set; }
+
+ ///
+ /// logit bias for specific tokens
+ ///
+ public Dictionary? LogitBias { get; set; }
+
+
+ ///
+ /// Sequences where the model will stop generating further tokens.
+ ///
+ public IEnumerable AntiPrompts { get; set; }
+
+ ///
+ /// path to file for saving/loading model eval state
+ ///
+ public string PathSession { get; set; }
+
+ ///
+ /// string to suffix user inputs with
+ ///
+ public string InputSuffix { get; set; }
+
+ ///
+ /// string to prefix user inputs with
+ ///
+ public string InputPrefix { get; set; }
+
+ ///
+ /// 0 or lower to use vocab size
+ ///
+ public int TopK { get; set; }
+
+ ///
+ /// 1.0 = disabled
+ ///
+ public float TopP { get; set; }
+
+ ///
+ /// 1.0 = disabled
+ ///
+ public float TfsZ { get; set; }
+
+ ///
+ /// 1.0 = disabled
+ ///
+ public float TypicalP { get; set; }
+
+ ///
+ /// 1.0 = disabled
+ ///
+ public float Temperature { get; set; }
+
+ ///
+ /// 1.0 = disabled
+ ///
+ public float RepeatPenalty { get; set; }
+
+ ///
+ /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
+ ///
+ public int RepeatLastTokensCount { get; set; }
+
+ ///
+ /// frequency penalty coefficient
+ /// 0.0 = disabled
+ ///
+ public float FrequencyPenalty { get; set; }
+
+ ///
+ /// presence penalty coefficient
+ /// 0.0 = disabled
+ ///
+ public float PresencePenalty { get; set; }
+
+ ///
+ /// Mirostat uses tokens instead of words.
+ /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
+ /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
+ ///
+ public MirostatType Mirostat { get; set; }
+
+ ///
+ /// target entropy
+ ///
+ public float MirostatTau { get; set; }
+
+ ///
+ /// learning rate
+ ///
+ public float MirostatEta { get; set; }
+
+ ///
+ /// consider newlines as a repeatable token (penalize_nl)
+ ///
+ public bool PenalizeNL { get; set; }
+ }
+}
\ No newline at end of file
diff --git a/LLama/Abstractions/ILLamaExecutor.cs b/LLama/Abstractions/ILLamaExecutor.cs
index d35e075e..6a750895 100644
--- a/LLama/Abstractions/ILLamaExecutor.cs
+++ b/LLama/Abstractions/ILLamaExecutor.cs
@@ -23,7 +23,7 @@ namespace LLama.Abstractions
/// Any additional parameters
/// A cancellation token.
///
- IEnumerable Infer(string text, InferenceParams? inferenceParams = null, CancellationToken token = default);
+ IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
///
/// Asynchronously infers a response from the model.
@@ -32,6 +32,6 @@ namespace LLama.Abstractions
/// Any additional parameters
/// A cancellation token.
///
- IAsyncEnumerable InferAsync(string text, InferenceParams? inferenceParams = null, CancellationToken token = default);
+ IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, CancellationToken token = default);
}
}
diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs
index b87e8984..4a4544b0 100644
--- a/LLama/ChatSession.cs
+++ b/LLama/ChatSession.cs
@@ -138,7 +138,7 @@ namespace LLama
///
///
///
- public IEnumerable Chat(ChatHistory history, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+ public IEnumerable Chat(ChatHistory history, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
var prompt = HistoryTransform.HistoryToText(history);
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
@@ -159,7 +159,7 @@ namespace LLama
///
///
///
- public IEnumerable Chat(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+ public IEnumerable Chat(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
foreach(var inputTransform in InputTransformPipeline)
{
@@ -182,7 +182,7 @@ namespace LLama
///
///
///
- public async IAsyncEnumerable ChatAsync(ChatHistory history, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable ChatAsync(ChatHistory history, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var prompt = HistoryTransform.HistoryToText(history);
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.User, prompt).Messages);
@@ -202,7 +202,7 @@ namespace LLama
///
///
///
- public async IAsyncEnumerable ChatAsync(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable ChatAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var inputTransform in InputTransformPipeline)
{
@@ -218,13 +218,13 @@ namespace LLama
History.Messages.AddRange(HistoryTransform.TextToHistory(AuthorRole.Assistant, sb.ToString()).Messages);
}
- private IEnumerable ChatInternal(string prompt, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+ private IEnumerable ChatInternal(string prompt, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
var results = _executor.Infer(prompt, inferenceParams, cancellationToken);
return OutputTransform.Transform(results);
}
- private async IAsyncEnumerable ChatAsyncInternal(string prompt, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ private async IAsyncEnumerable ChatAsyncInternal(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var results = _executor.InferAsync(prompt, inferenceParams, cancellationToken);
await foreach (var item in OutputTransform.TransformAsync(results))
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index 77af7eaf..001a8f8e 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -1,4 +1,5 @@
-using System;
+using LLama.Abstractions;
+using System;
using System.Collections.Generic;
namespace LLama.Common
@@ -7,7 +8,7 @@ namespace LLama.Common
///
/// The paramters used for inference.
///
- public class InferenceParams
+ public class InferenceParams : IInferenceParams
{
///
/// number of tokens to keep from initial prompt
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index afbc0f25..dbd1b593 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -231,13 +231,13 @@ namespace LLama
///
///
///
- protected abstract bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs);
+ protected abstract bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs);
///
/// The core inference logic.
///
///
///
- protected abstract void InferInternal(InferenceParams inferenceParams, InferStateArgs args);
+ protected abstract void InferInternal(IInferenceParams inferenceParams, InferStateArgs args);
///
/// Save the current state to a file.
///
@@ -267,7 +267,7 @@ namespace LLama
///
///
///
- public virtual IEnumerable Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+ public virtual IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
if (inferenceParams is null)
@@ -324,7 +324,7 @@ namespace LLama
///
///
///
- public virtual async IAsyncEnumerable InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public virtual async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var result in Infer(text, inferenceParams, cancellationToken))
{
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 89fbac59..e055c147 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -1,4 +1,5 @@
-using LLama.Common;
+using LLama.Abstractions;
+using LLama.Common;
using LLama.Native;
using System;
using System.Collections.Generic;
@@ -136,7 +137,7 @@ namespace LLama
}
}
///
- protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs)
+ protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs)
{
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
@@ -179,7 +180,7 @@ namespace LLama
return false;
}
///
- protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args)
+ protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embeds.Count > 0)
{
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index bc3a242e..f5c1583e 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -1,5 +1,6 @@
using LLama.Common;
using LLama.Native;
+using LLama.Abstractions;
using System;
using System.Collections.Generic;
using System.IO;
@@ -122,7 +123,7 @@ namespace LLama
///
///
///
- protected override bool PostProcess(InferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs)
+ protected override bool PostProcess(IInferenceParams inferenceParams, InferStateArgs args, out IEnumerable? extraOutputs)
{
extraOutputs = null;
if (_embed_inps.Count <= _consumedTokensCount)
@@ -166,7 +167,7 @@ namespace LLama
}
///
- protected override void InferInternal(InferenceParams inferenceParams, InferStateArgs args)
+ protected override void InferInternal(IInferenceParams inferenceParams, InferStateArgs args)
{
if (_embeds.Count > 0)
{
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 88fa1695..dd0497c9 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -36,7 +36,7 @@ namespace LLama
}
///
- public IEnumerable Infer(string text, InferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
+ public IEnumerable Infer(string text, IInferenceParams? inferenceParams = null, CancellationToken cancellationToken = default)
{
cancellationToken.ThrowIfCancellationRequested();
int n_past = 1;
@@ -123,7 +123,7 @@ namespace LLama
}
///
- public async IAsyncEnumerable InferAsync(string text, InferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
foreach (var result in Infer(text, inferenceParams, cancellationToken))
{