diff --git a/LLama.Examples/NewVersion/LoadAndSaveSession.cs b/LLama.Examples/NewVersion/LoadAndSaveSession.cs index 33774b13..9e6116ce 100644 --- a/LLama.Examples/NewVersion/LoadAndSaveSession.cs +++ b/LLama.Examples/NewVersion/LoadAndSaveSession.cs @@ -8,7 +8,7 @@ namespace LLama.Examples.NewVersion { Console.Write("Please input your model path: "); var modelPath = Console.ReadLine(); - var prompt = File.ReadAllText("Assets/chat-with-bob.txt").Trim(); + var prompt = (await File.ReadAllTextAsync("Assets/chat-with-bob.txt")).Trim(); var parameters = new ModelParams(modelPath) { @@ -50,7 +50,7 @@ namespace LLama.Examples.NewVersion Console.ForegroundColor = ConsoleColor.White; ex.Context.Dispose(); - ex = new(new LLamaContext(parameters)); + ex = new(new LLamaContext(model, parameters)); session = new ChatSession(ex); session.LoadSession(statePath); diff --git a/LLama.Examples/NewVersion/SemanticKernelChat.cs b/LLama.Examples/NewVersion/SemanticKernelChat.cs index 9bdbcfec..9fd59058 100644 --- a/LLama.Examples/NewVersion/SemanticKernelChat.cs +++ b/LLama.Examples/NewVersion/SemanticKernelChat.cs @@ -1,13 +1,7 @@ -using System.Reflection.Metadata; -using System.Security.Cryptography; -using System.Text; -using LLama.Abstractions; +using System.Security.Cryptography; using LLama.Common; -using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.AI.ChatCompletion; -using Microsoft.SemanticKernel.AI.TextCompletion; using LLamaSharp.SemanticKernel.ChatCompletion; -using LLamaSharp.SemanticKernel.TextCompletion; namespace LLama.Examples.NewVersion { @@ -22,7 +16,7 @@ namespace LLama.Examples.NewVersion // Load weights into memory var parameters = new ModelParams(modelPath) { - Seed = RandomNumberGenerator.GetInt32(int.MaxValue), + Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)), }; using var model = LLamaWeights.LoadFromFile(parameters); using var context = model.CreateContext(parameters); diff --git a/LLama.Examples/NewVersion/SemanticKernelMemory.cs b/LLama.Examples/NewVersion/SemanticKernelMemory.cs index 316b3611..1b15b5ad 100644 --- a/LLama.Examples/NewVersion/SemanticKernelMemory.cs +++ b/LLama.Examples/NewVersion/SemanticKernelMemory.cs @@ -18,7 +18,7 @@ namespace LLama.Examples.NewVersion Console.Write("Please input your model path: "); var modelPath = Console.ReadLine(); - var seed = 1337; + var seed = 1337u; // Load weights into memory var parameters = new ModelParams(modelPath) { diff --git a/LLama.Examples/NewVersion/SemanticKernelPrompt.cs b/LLama.Examples/NewVersion/SemanticKernelPrompt.cs index 93a68773..7764d982 100644 --- a/LLama.Examples/NewVersion/SemanticKernelPrompt.cs +++ b/LLama.Examples/NewVersion/SemanticKernelPrompt.cs @@ -18,7 +18,7 @@ namespace LLama.Examples.NewVersion // Load weights into memory var parameters = new ModelParams(modelPath) { - Seed = RandomNumberGenerator.GetInt32(int.MaxValue), + Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)) }; using var model = LLamaWeights.LoadFromFile(parameters); var ex = new StatelessExecutor(model, parameters); diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs index 309d5654..4c412c93 100644 --- a/LLama.Examples/NewVersion/TalkToYourself.cs +++ b/LLama.Examples/NewVersion/TalkToYourself.cs @@ -15,7 +15,7 @@ namespace LLama.Examples.NewVersion // Load weights into memory var @params = new ModelParams(modelPath) { - Seed = RandomNumberGenerator.GetInt32(int.MaxValue) + Seed = 
unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)) }; using var weights = LLamaWeights.LoadFromFile(@params); diff --git a/LLama.Examples/Program.cs b/LLama.Examples/Program.cs index e455e9e5..d9c8a890 100644 --- a/LLama.Examples/Program.cs +++ b/LLama.Examples/Program.cs @@ -1,4 +1,5 @@ using LLama.Examples.NewVersion; +using LLama.Native; Console.WriteLine("======================================================================================================"); @@ -7,7 +8,7 @@ Console.WriteLine(" __ __ ____ _ Console.WriteLine("======================================================================================================"); - +NativeApi.llama_empty_call(); Console.WriteLine(); await NewVersionTestRunner.Run(); \ No newline at end of file diff --git a/LLama.Unittest/BasicTest.cs b/LLama.Unittest/BasicTest.cs index 0142e0d9..2cd1806f 100644 --- a/LLama.Unittest/BasicTest.cs +++ b/LLama.Unittest/BasicTest.cs @@ -27,36 +27,8 @@ namespace LLama.Unittest public void BasicModelProperties() { Assert.Equal(32000, _model.VocabCount); - Assert.Equal(2048, _model.ContextSize); + Assert.Equal(4096, _model.ContextSize); Assert.Equal(4096, _model.EmbeddingSize); - Assert.Equal(Encoding.UTF8, _model.Encoding); - } - - [Fact] - public void CloneContext() - { - var original = _model.CreateContext(_params); - - // Evaluate something (doesn't matter what, as long as it begins with token 1) - original.Eval(new[] { 1, 42, 321 }, 0); - - // Clone current state - var clone = original.Clone(); - - // Now evaluate something more - var reply1a = original.Eval(new[] { 4, 5, 6 }, 3); - var reply2a = original.Eval(new[] { 7, 8, 9 }, 6); - - // Assert that the context replied differently each time - Assert.NotEqual(reply1a, reply2a); - - // Give the same prompts to the cloned state - var reply1b = clone.Eval(new[] { 4, 5, 6 }, 3); - var reply2b = clone.Eval(new[] { 7, 8, 9 }, 6); - - // Assert that the cloned context replied in the same way as originally - Assert.Equal(reply1a, reply1b); - Assert.Equal(reply2a, reply2b); } } } \ No newline at end of file diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs index 198511f1..2edf3a62 100644 --- a/LLama.Unittest/LLamaContextTests.cs +++ b/LLama.Unittest/LLamaContextTests.cs @@ -2,7 +2,7 @@ namespace LLama.Unittest { - public class LLamaContextTests + public sealed class LLamaContextTests : IDisposable { private readonly LLamaWeights _weights; @@ -30,7 +30,6 @@ namespace LLama.Unittest Assert.Equal(768, _context.ContextSize); Assert.Equal(4096, _context.EmbeddingSize); Assert.Equal(32000, _context.VocabCount); - Assert.Equal(0, _context.KVCacheTokenCount); } [Fact] diff --git a/LLama.Unittest/ModelsParamsTests.cs b/LLama.Unittest/ModelsParamsTests.cs index 413bda83..000f5853 100644 --- a/LLama.Unittest/ModelsParamsTests.cs +++ b/LLama.Unittest/ModelsParamsTests.cs @@ -13,7 +13,6 @@ namespace LLama.Unittest { BatchSize = 17, ContextSize = 42, - LoraAdapter = "adapter", Seed = 42, GpuLayerCount = 111 }; @@ -31,9 +30,13 @@ namespace LLama.Unittest { BatchSize = 17, ContextSize = 42, - LoraAdapter = "adapter", Seed = 42, - GpuLayerCount = 111 + GpuLayerCount = 111, + LoraAdapters = + { + new("abc", 1), + new("def", 0) + } }; var settings = new Newtonsoft.Json.JsonSerializerSettings(); diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index 47ad3392..19f618af 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -16,7 +16,7 @@ 
namespace LLama.Unittest _params = new ModelParams(Constants.ModelPath) { ContextSize = 60, - Seed = 1754 + Seed = 1754, }; _weights = LLamaWeights.LoadFromFile(_params); } @@ -48,13 +48,13 @@ namespace LLama.Unittest { var executor = new StatelessExecutor(_weights, _params); - const string question = " Question. why is a cat the best pet?\nAnswer: "; + const string question = " Question. cats or dogs?\nAnswer: "; // The context size is set to 60. Generate more than that, forcing it to generate a coherent response // with a modified context var @params = new InferenceParams() { - MaxTokens = 100, + MaxTokens = 65, TokensKeep = question.Length, }; diff --git a/LLama.Unittest/TokenTests.cs b/LLama.Unittest/TokenTests.cs index 7ec484b8..a699b9b8 100644 --- a/LLama.Unittest/TokenTests.cs +++ b/LLama.Unittest/TokenTests.cs @@ -27,7 +27,7 @@ public sealed class TokenTests [Fact] public void TokensEndWith() { - var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); var result = tokens.TokensEndsWithAnyString(new[] { @@ -41,7 +41,7 @@ public sealed class TokenTests [Fact] public void TokensEndSubstring() { - var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); var result = tokens.TokensEndsWithAnyString((IList)new[] { @@ -53,7 +53,7 @@ public sealed class TokenTests [Fact] public void TokensNotEndWith() { - var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); var result = tokens.TokensEndsWithAnyString((IList)new[] { @@ -67,7 +67,7 @@ public sealed class TokenTests [Fact] public void TokensNotEndWithNothing() { - var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, Encoding.UTF8); + var tokens = _model.NativeHandle.Tokenize("The cat sat on the edge of the mat", false, true, Encoding.UTF8); var result = tokens.TokensEndsWithAnyString((IList)Array.Empty(), _model.NativeHandle, Encoding.UTF8); Assert.False(result); diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index f06757e3..2829b99e 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -4,7 +4,7 @@ using LLama.Abstractions; namespace LLama.Web.Common { public class ModelOptions - : IModelParams + : ILLamaParams { public string Name { get; set; } @@ -14,7 +14,7 @@ namespace LLama.Web.Common /// /// Model context size (n_ctx) /// - public int ContextSize { get; set; } = 512; + public uint ContextSize { get; set; } = 512; /// /// the GPU that is used for scratch and small tensors /// @@ -30,7 +30,7 @@ namespace LLama.Web.Common /// /// Seed for the random number generator (seed) /// - public int Seed { get; set; } = 1686349486; + public uint Seed { get; set; } = 1686349486; /// /// Use f16 instead of f32 for memory kv (memory_f16) /// @@ -51,26 +51,31 @@ namespace LLama.Web.Common /// Model path (model) /// public string ModelPath { get; set; } + /// - /// model alias - /// - public string ModelAlias { get; set; } = "unknown"; - /// - /// lora adapter path (lora_adapter) - /// - public string LoraAdapter { get; set; } = string.Empty; - /// - /// base model path for the lora 
adapter (lora_base) - /// - public string LoraBase { get; set; } = string.Empty; - /// - /// Number of threads (-1 = autodetect) (n_threads) + /// List of LoRAs to apply /// - public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1); + public AdapterCollection LoraAdapters { get; set; } = new(); + + /// + /// base model path for the lora adapter (lora_base) + /// + public string LoraBase { get; set; } = string.Empty; + /// - /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// Number of threads (null = autodetect) (n_threads) /// - public int BatchSize { get; set; } = 512; + public uint? Threads { get; set; } + + /// + /// Number of threads to use for batch processing (null = autodetect) (n_threads) + /// + public uint? BatchThreads { get; set; } + + /// + /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// + public uint BatchSize { get; set; } = 512; /// /// Whether to convert eos to newline during the inference. @@ -107,5 +112,10 @@ namespace LLama.Web.Common /// The encoding to use for models /// public Encoding Encoding { get; set; } = Encoding.UTF8; + + /// + /// Load vocab only (no weights) + /// + public bool VocabOnly { get; set; } } } diff --git a/LLama.Web/Services/ConnectionSessionService.cs b/LLama.Web/Services/ConnectionSessionService.cs index 7dfcde39..b5867d9b 100644 --- a/LLama.Web/Services/ConnectionSessionService.cs +++ b/LLama.Web/Services/ConnectionSessionService.cs @@ -3,7 +3,6 @@ using LLama.Web.Common; using LLama.Web.Models; using Microsoft.Extensions.Options; using System.Collections.Concurrent; -using System.Drawing; namespace LLama.Web.Services { @@ -50,15 +49,16 @@ namespace LLama.Web.Services if (modelOption.MaxInstances > -1 && currentInstances >= modelOption.MaxInstances) return Task.FromResult(ServiceResult.FromError("Maximum model instances reached")); - // Create model - var llamaModel = new LLamaContext(modelOption); + // Load weights + // todo: it would be better to have a central service which loads weights and shares them between all contexts that need them! 
+ using var weights = LLamaWeights.LoadFromFile(modelOption); // Create executor ILLamaExecutor executor = executorType switch { - LLamaExecutorType.Interactive => new InteractiveExecutor(llamaModel), - LLamaExecutorType.Instruct => new InstructExecutor(llamaModel), - LLamaExecutorType.Stateless => new StatelessExecutor(llamaModel), + LLamaExecutorType.Interactive => new InteractiveExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext + LLamaExecutorType.Instruct => new InstructExecutor(new LLamaContext(weights, modelOption)), //todo: properly dispose of LLamaContext + LLamaExecutorType.Stateless => new StatelessExecutor(weights, modelOption), _ => default }; diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index ab542694..f1eb3538 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -16,10 +16,15 @@ public class StatefulChatService : IDisposable public StatefulChatService(IConfiguration configuration) { - _context = new LLamaContext(new Common.ModelParams(configuration["ModelPath"]) + var @params = new Common.ModelParams(configuration["ModelPath"]) { - ContextSize = 512 - }); + ContextSize = 512, + }; + + // todo: share weights from a central service + using var weights = LLamaWeights.LoadFromFile(@params); + + _context = new LLamaContext(weights, @params); _session = new ChatSession(new InteractiveExecutor(_context)); } diff --git a/LLama.WebAPI/Services/StatelessChatService.cs b/LLama.WebAPI/Services/StatelessChatService.cs index b924f4d8..71da775f 100644 --- a/LLama.WebAPI/Services/StatelessChatService.cs +++ b/LLama.WebAPI/Services/StatelessChatService.cs @@ -12,10 +12,16 @@ namespace LLama.WebAPI.Services public StatelessChatService(IConfiguration configuration) { - _context = new LLamaContext(new ModelParams(configuration["ModelPath"]) + var @params = new Common.ModelParams(configuration["ModelPath"]) { ContextSize = 512, - }); + }; + + // todo: share weights from a central service + using var weights = LLamaWeights.LoadFromFile(@params); + + _context = new LLamaContext(weights, @params); + // TODO: replace with a stateless executor _session = new ChatSession(new InteractiveExecutor(_context)) .WithOutputTransform(new LLamaTransforms.KeywordTextOutputStreamTransform(new string[] { "User:", "Assistant:" }, redundancyLength: 8)) diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs new file mode 100644 index 00000000..201a9b9a --- /dev/null +++ b/LLama/Abstractions/IContextParams.cs @@ -0,0 +1,70 @@ +using System.Text; + +namespace LLama.Abstractions; + +/// +/// The parameters for initializing a LLama context from a model. +/// +public interface IContextParams +{ + /// + /// Model context size (n_ctx) + /// + uint ContextSize { get; set; } + + /// + /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) + /// + uint BatchSize { get; set; } + + /// + /// Seed for the random number generator (seed) + /// + uint Seed { get; set; } + + /// + /// Use f16 instead of f32 for memory kv (memory_f16) + /// + bool UseFp16Memory { get; set; } + + /// + /// Compute perplexity over the prompt (perplexity) + /// + bool Perplexity { get; set; } + + /// + /// Whether to use embedding mode. (embedding) Note that if this is set to true, + /// The LLamaModel won't produce text response anymore. 
+ /// + bool EmbeddingMode { get; set; } + + /// + /// RoPE base frequency + /// + float RopeFrequencyBase { get; set; } + + /// + /// RoPE frequency scaling factor + /// + float RopeFrequencyScale { get; set; } + + /// + /// Use experimental mul_mat_q kernels + /// + bool MulMatQ { get; set; } + + /// + /// The encoding to use for models + /// + Encoding Encoding { get; set; } + + /// + /// Number of threads (null = autodetect) (n_threads) + /// + uint? Threads { get; set; } + + /// + /// Number of threads to use for batch processing (null = autodetect) (n_threads) + /// + uint? BatchThreads { get; set; } +} \ No newline at end of file diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index 08856ce3..93a9b52b 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -36,7 +36,7 @@ namespace LLama.Abstractions /// public int TopK { get; set; } - /// + /// llama_eval /// 1.0 = disabled /// public float TopP { get; set; } diff --git a/LLama/Abstractions/ILLamaParams.cs b/LLama/Abstractions/ILLamaParams.cs new file mode 100644 index 00000000..636ba199 --- /dev/null +++ b/LLama/Abstractions/ILLamaParams.cs @@ -0,0 +1,11 @@ +namespace LLama.Abstractions +{ + /// + /// Convenience interface for implementing both type of parameters. + /// + /// Mostly exists for backwards compatibility reasons, when these two were not split. + public interface ILLamaParams + : IModelParams, IContextParams + { + } +} diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index ad0608d7..8b4e8497 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -1,4 +1,6 @@ -using System.Text; +using System; +using System.Collections.Generic; +using System.Linq; namespace LLama.Abstractions { @@ -7,36 +9,16 @@ namespace LLama.Abstractions /// public interface IModelParams { - /// - /// Model context size (n_ctx) - /// - int ContextSize { get; set; } - /// /// the GPU that is used for scratch and small tensors /// int MainGpu { get; set; } - /// - /// if true, reduce VRAM usage at the cost of performance - /// - bool LowVram { get; set; } - /// /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) /// int GpuLayerCount { get; set; } - /// - /// Seed for the random number generator (seed) - /// - int Seed { get; set; } - - /// - /// Use f16 instead of f32 for memory kv (memory_f16) - /// - bool UseFp16Memory { get; set; } - /// /// Use mmap for faster loads (use_mmap) /// @@ -47,41 +29,15 @@ namespace LLama.Abstractions /// bool UseMemoryLock { get; set; } - /// - /// Compute perplexity over the prompt (perplexity) - /// - bool Perplexity { get; set; } - /// /// Model path (model) /// string ModelPath { get; set; } - /// - /// lora adapter path (lora_adapter) - /// - string LoraAdapter { get; set; } - - /// - /// base model path for the lora adapter (lora_base) - /// - string LoraBase { get; set; } - /// /// Number of threads (-1 = autodetect) (n_threads) /// - int Threads { get; set; } - - /// - /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) - /// - int BatchSize { get; set; } - - /// - /// Whether to use embedding mode. (embedding) Note that if this is set to true, - /// The LLamaModel won't produce text response anymore. - /// - bool EmbeddingMode { get; set; } + uint? Threads { get; set; } /// /// how split tensors should be distributed across GPUs @@ -89,23 +45,62 @@ namespace LLama.Abstractions float[]? 
TensorSplits { get; set; } /// - /// RoPE base frequency + /// Load vocab only (no weights) /// - float RopeFrequencyBase { get; set; } + bool VocabOnly { get; set; } /// - /// RoPE frequency scaling factor + /// List of LoRA adapters to apply /// - float RopeFrequencyScale { get; set; } + AdapterCollection LoraAdapters { get; } /// - /// Use experimental mul_mat_q kernels + /// base model path for the lora adapter (lora_base) /// - bool MulMatQ { get; set; } + string LoraBase { get; set; } + } - /// - /// The encoding to use for models - /// - Encoding Encoding { get; set; } + /// + /// A LoRA adapter to apply to a model + /// + /// Path to the LoRA file + /// Strength of this LoRA + public readonly record struct LoraAdapter(string Path, float Scale); + + /// + /// A list of LoraAdapter objects + /// + public sealed class AdapterCollection + : List, IEquatable + { + /// + public bool Equals(AdapterCollection? other) + { + if (other == null) + return false; + + return this.SequenceEqual(other); + } + + /// + public override bool Equals(object? obj) + { + return Equals(obj as AdapterCollection); + } + + /// + public override int GetHashCode() + { + unchecked + { + var hash = 17; + for (var i = 0; i < Count; i++) + { + hash += this[i].GetHashCode(); + hash *= 7823; + } + return hash; + } + } } } \ No newline at end of file diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 1ce18dd8..ed877853 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -10,20 +10,17 @@ namespace LLama.Common /// The parameters for initializing a LLama model. /// public record ModelParams - : IModelParams + : ILLamaParams { /// /// Model context size (n_ctx) /// - public int ContextSize { get; set; } = 512; + public uint ContextSize { get; set; } = 512; /// /// the GPU that is used for scratch and small tensors /// public int MainGpu { get; set; } = 0; - /// - /// if true, reduce VRAM usage at the cost of performance - /// - public bool LowVram { get; set; } = false; + /// /// Number of layers to run in VRAM / GPU memory (n_gpu_layers) /// @@ -31,7 +28,7 @@ namespace LLama.Common /// /// Seed for the random number generator (seed) /// - public int Seed { get; set; } = 1686349486; + public uint Seed { get; set; } = 1686349486; /// /// Use f16 instead of f32 for memory kv (memory_f16) /// @@ -52,22 +49,31 @@ namespace LLama.Common /// Model path (model) /// public string ModelPath { get; set; } + /// - /// lora adapter path (lora_adapter) + /// List of LoRAs to apply /// - public string LoraAdapter { get; set; } = string.Empty; + public AdapterCollection LoraAdapters { get; set; } = new(); + /// /// base model path for the lora adapter (lora_base) /// public string LoraBase { get; set; } = string.Empty; + /// - /// Number of threads (-1 = autodetect) (n_threads) + /// Number of threads (null = autodetect) (n_threads) /// - public int Threads { get; set; } = Math.Max(Environment.ProcessorCount / 2, 1); + public uint? Threads { get; set; } + + /// + /// Number of threads to use for batch processing (null = autodetect) (n_threads) + /// + public uint? BatchThreads { get; set; } + /// /// batch size for prompt processing (must be >=32 to use BLAS) (n_batch) /// - public int BatchSize { get; set; } = 512; + public uint BatchSize { get; set; } = 512; /// /// Whether to use embedding mode. 
(embedding) Note that if this is set to true, @@ -95,6 +101,11 @@ namespace LLama.Common /// public bool MulMatQ { get; set; } + /// + /// Load vocab only (no weights) + /// + public bool VocabOnly { get; set; } + /// /// The encoding to use to convert text for the model /// @@ -138,10 +149,10 @@ namespace LLama.Common /// Use experimental mul_mat_q kernels /// The encoding to use to convert text for the model [Obsolete("Use object initializer to set all optional parameters")] - public ModelParams(string modelPath, int contextSize = 512, int gpuLayerCount = 20, - int seed = 1337, bool useFp16Memory = true, + public ModelParams(string modelPath, uint contextSize = 512, int gpuLayerCount = 20, + uint seed = 1337, bool useFp16Memory = true, bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, - string loraAdapter = "", string loraBase = "", int threads = -1, int batchSize = 512, + string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512, bool embeddingMode = false, float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false, string encoding = "UTF-8") @@ -154,15 +165,15 @@ namespace LLama.Common UseMemoryLock = useMemoryLock; Perplexity = perplexity; ModelPath = modelPath; - LoraAdapter = loraAdapter; LoraBase = loraBase; - Threads = threads == -1 ? Math.Max(Environment.ProcessorCount / 2, 1) : threads; + Threads = threads < 1 ? null : (uint)threads; BatchSize = batchSize; EmbeddingMode = embeddingMode; RopeFrequencyBase = ropeFrequencyBase; RopeFrequencyScale = ropeFrequencyScale; MulMatQ = mulMatQ; Encoding = Encoding.GetEncoding(encoding); + LoraAdapters.Add(new LoraAdapter(loraAdapter, 1)); } } diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs new file mode 100644 index 00000000..7ca508a2 --- /dev/null +++ b/LLama/Extensions/IContextParamsExtensions.cs @@ -0,0 +1,46 @@ +using System; +using System.IO; +using LLama.Abstractions; +using LLama.Native; + +namespace LLama.Extensions +{ + /// + /// Extention methods to the IContextParams interface + /// + public static class IContextParamsExtensions + { + /// + /// Convert the given `IModelParams` into a `LLamaContextParams` + /// + /// + /// + /// + /// + /// + public static void ToLlamaContextParams(this IContextParams @params, out LLamaContextParams result) + { + result = NativeApi.llama_context_default_params(); + result.n_ctx = @params.ContextSize; + result.n_batch = @params.BatchSize; + result.seed = @params.Seed; + result.f16_kv = @params.UseFp16Memory; + result.logits_all = @params.Perplexity; + result.embedding = @params.EmbeddingMode; + result.rope_freq_base = @params.RopeFrequencyBase; + result.rope_freq_scale = @params.RopeFrequencyScale; + result.mul_mat_q = @params.MulMatQ; + + result.n_threads = Threads(@params.Threads); + result.n_threads_batch = Threads(@params.BatchThreads); + } + + private static uint Threads(uint? 
value) + { + if (value is > 0) + return (uint)value; + + return (uint)Math.Max(Environment.ProcessorCount / 2, 1); + } + } +} diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs index c4cb1c62..56cd7aaa 100644 --- a/LLama/Extensions/IModelParamsExtensions.cs +++ b/LLama/Extensions/IModelParamsExtensions.cs @@ -12,41 +12,30 @@ namespace LLama.Extensions public static class IModelParamsExtensions { /// - /// Convert the given `IModelParams` into a `LLamaContextParams` + /// Convert the given `IModelParams` into a `LLamaModelParams` /// /// /// /// /// /// - public static MemoryHandle ToLlamaContextParams(this IModelParams @params, out LLamaContextParams result) + public static MemoryHandle ToLlamaModelParams(this IModelParams @params, out LLamaModelParams result) { - if (!File.Exists(@params.ModelPath)) - throw new FileNotFoundException($"The model file does not exist: {@params.ModelPath}"); - if (@params.TensorSplits != null && @params.TensorSplits.Length != 1) throw new ArgumentException("Currently multi-gpu support is not supported by both llama.cpp and LLamaSharp."); - result = NativeApi.llama_context_default_params(); - result.n_ctx = @params.ContextSize; - result.n_batch = @params.BatchSize; + result = NativeApi.llama_model_default_params(); + result.main_gpu = @params.MainGpu; result.n_gpu_layers = @params.GpuLayerCount; - result.seed = @params.Seed; - result.f16_kv = @params.UseFp16Memory; - result.use_mmap = @params.UseMemorymap; result.use_mlock = @params.UseMemoryLock; - result.logits_all = @params.Perplexity; - result.embedding = @params.EmbeddingMode; - result.low_vram = @params.LowVram; - result.rope_freq_base = @params.RopeFrequencyBase; - result.rope_freq_scale = @params.RopeFrequencyScale; - result.mul_mat_q = @params.MulMatQ; + result.use_mmap = @params.UseMemorymap; + result.vocab_only = @params.VocabOnly; var pin = @params.TensorSplits.AsMemory().Pin(); unsafe { - result.tensor_split = (nint)pin.Pointer; + result.tensor_split = (float*)pin.Pointer; } return pin; diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 2e0340e8..5a9f4893 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -42,14 +42,9 @@ namespace LLama public int EmbeddingSize => _ctx.EmbeddingSize; /// - /// Get the number of tokens in the KV Cache for this context + /// The context params set for this context /// - public int KVCacheTokenCount => _ctx.KVCacheTokenCount; - - /// - /// The model params set for this model. - /// - public IModelParams Params { get; set; } + public IContextParams Params { get; set; } /// /// The native handle, which is used to be passed to the native APIs @@ -62,24 +57,7 @@ namespace LLama /// public Encoding Encoding => _encoding; - /// - /// - /// - /// Model params. - /// The logger. - [Obsolete("Use the LLamaWeights.CreateContext instead")] - public LLamaContext(IModelParams @params, ILogger? logger = null) - { - Params = @params; - - _logger = logger; - _encoding = @params.Encoding; - - _logger?.LogInformation($"[LLamaContext] Initializing LLama model with params: {this.Params}"); - _ctx = Utils.InitLLamaContextFromModelParams(Params); - } - - internal LLamaContext(SafeLLamaContextHandle nativeContext, IModelParams @params, ILogger? logger = null) + internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null) { Params = @params; @@ -95,7 +73,7 @@ namespace LLama /// /// /// - public LLamaContext(LLamaWeights model, IModelParams @params, ILogger? 
logger = null) + public LLamaContext(LLamaWeights model, IContextParams @params, ILogger? logger = null) { if (model.NativeHandle.IsClosed) throw new ObjectDisposedException("Cannot create context, model weights have been disposed"); @@ -105,30 +83,20 @@ namespace LLama _logger = logger; _encoding = @params.Encoding; - using var pin = @params.ToLlamaContextParams(out var lparams); + @params.ToLlamaContextParams(out var lparams); _ctx = SafeLLamaContextHandle.Create(model.NativeHandle, lparams); } - /// - /// Create a copy of the current state of this context - /// - /// - public LLamaContext Clone() - { - using var pin = Params.ToLlamaContextParams(out var lparams); - var clone = _ctx.Clone(lparams); - return new LLamaContext(clone, Params); - } - /// /// Tokenize a string. /// /// /// Whether to add a bos to the text. + /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. /// - public llama_token[] Tokenize(string text, bool addBos = true) + public llama_token[] Tokenize(string text, bool addBos = true, bool special = false) { - return _ctx.Tokenize(text, addBos, _encoding); + return _ctx.Tokenize(text, addBos, special, _encoding); } /// @@ -177,19 +145,6 @@ namespace LLama fileStream.SetLength(writtenBytes); } - /// - /// Get the state data as a byte array. - /// - /// - [Obsolete("Use `GetState` instead, this supports larger states (over 2GB)")] - public byte[] GetStateData() - { - var stateSize = NativeApi.llama_get_state_size(_ctx); - byte[] stateMemory = new byte[stateSize]; - NativeApi.llama_copy_state_data(_ctx, stateMemory); - return stateMemory; - } - /// /// Get the state data as an opaque handle /// @@ -198,31 +153,28 @@ namespace LLama { var stateSize = _ctx.GetStateSize(); - unsafe + // Allocate a chunk of memory large enough to hold the entire state + var memory = Marshal.AllocHGlobal((nint)stateSize); + try { - // Allocate a chunk of memory large enough to hold the entire state - var memory = Marshal.AllocHGlobal((nint)stateSize); - try - { - // Copy the state data into memory, discover the actual size required - var actualSize = _ctx.GetState(memory, stateSize); + // Copy the state data into memory, discover the actual size required + var actualSize = _ctx.GetState(memory, stateSize); - // Shrink to size - memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); + // Shrink to size + memory = Marshal.ReAllocHGlobal(memory, (nint)actualSize); - // Wrap memory in a "state" - var state = new State(memory); + // Wrap memory in a "state" + var state = new State(memory); - // Set memory to zero, to prevent it being freed in finally block - memory = IntPtr.Zero; + // Set memory to zero, to prevent it being freed in finally block + memory = IntPtr.Zero; - return state; - } - finally - { - if (memory != IntPtr.Zero) - Marshal.FreeHGlobal(memory); - } + return state; + } + finally + { + if (memory != IntPtr.Zero) + Marshal.FreeHGlobal(memory); } } @@ -247,21 +199,6 @@ namespace LLama } } - /// - /// Load the state from memory. - /// - /// - /// - public void LoadState(byte[] stateData) - { - int stateSize = (int)NativeApi.llama_get_state_size(_ctx); - if (stateData.Length > stateSize) - { - throw new RuntimeError("Failed to validate state size."); - } - NativeApi.llama_set_state_data(_ctx, stateData); - } - /// /// Load the state from memory. 
/// @@ -463,15 +400,15 @@ namespace LLama public int Eval(ReadOnlySpan tokens, int pastTokensCount) { var total = tokens.Length; - for(var i = 0; i < total; i += Params.BatchSize) + for(var i = 0; i < total; i += (int)Params.BatchSize) { var n_eval = total - i; if (n_eval > Params.BatchSize) { - n_eval = Params.BatchSize; + n_eval = (int)Params.BatchSize; } - if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount, Params.Threads)) + if (!_ctx.Eval(tokens.Slice(i, n_eval), pastTokensCount)) { _logger?.LogError($"[LLamaContext] Failed to eval."); throw new RuntimeError("Failed to eval."); diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index 64c17539..54ef07b0 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -18,19 +18,22 @@ namespace LLama /// public int EmbeddingSize => _ctx.EmbeddingSize; - /// - /// - /// - /// - public LLamaEmbedder(IModelParams @params) + public LLamaEmbedder(ILLamaParams allParams) + : this(allParams, allParams) { - @params.EmbeddingMode = true; - using var weights = LLamaWeights.LoadFromFile(@params); - _ctx = weights.CreateContext(@params); } - public LLamaEmbedder(LLamaWeights weights, IModelParams @params) + public LLamaEmbedder(IModelParams modelParams, IContextParams contextParams) { + using var weights = LLamaWeights.LoadFromFile(modelParams); + + contextParams.EmbeddingMode = true; + _ctx = weights.CreateContext(contextParams); + } + + public LLamaEmbedder(LLamaWeights weights, IContextParams @params) + { + @params.EmbeddingMode = true; _ctx = weights.CreateContext(@params); } diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index 3ff755a0..583b6ca0 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -7,6 +7,7 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using LLama.Extensions; +using LLama.Native; namespace LLama { @@ -20,7 +21,7 @@ namespace LLama : ILLamaExecutor { private readonly LLamaWeights _weights; - private readonly IModelParams _params; + private readonly IContextParams _params; /// /// The context used by the executor when running the inference. @@ -32,7 +33,7 @@ namespace LLama /// /// /// - public StatelessExecutor(LLamaWeights weights, IModelParams @params) + public StatelessExecutor(LLamaWeights weights, IContextParams @params) { _weights = weights; _params = @params; @@ -41,20 +42,6 @@ namespace LLama Context.Dispose(); } - /// - /// Create a new stateless executor which will use the model used to create the given context - /// - /// - [Obsolete("Use the constructor which automatically creates contexts using the LLamaWeights")] - public StatelessExecutor(LLamaContext context) - { - _weights = new LLamaWeights(context.NativeHandle.ModelHandle, context.Params.Encoding); - _params = context.Params; - - Context = _weights.CreateContext(_params); - Context.Dispose(); - } - /// public async IAsyncEnumerable InferAsync(string text, IInferenceParams? 
inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { @@ -114,15 +101,16 @@ namespace LLama break; // when run out of context - // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L433 - if (n_past + tokens.Count > Context.ContextSize) + // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497 + if (n_past + tokens.Count >= Context.ContextSize) { - var n_left = n_past - inferenceParams.TokensKeep; + var n_left = n_past - inferenceParams.TokensKeep - 1; + var n_discard = n_left / 2; - n_past = Math.Max(1, inferenceParams.TokensKeep); + NativeApi.llama_kv_cache_seq_rm(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1, inferenceParams.TokensKeep + n_discard + 1); + NativeApi.llama_kv_cache_seq_shift(Context.NativeHandle, (LLamaSeqId)0, inferenceParams.TokensKeep + 1 + n_discard, n_past, -n_discard); - tokens.Clear(); - tokens.AddRange(lastTokens.Skip(lastTokens.Count - n_left / 2).Take(n_left / 2)); + n_past -= n_discard; } // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently) diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 1b067f1b..bcc41afb 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -1,5 +1,4 @@ using System; -using System.Text; using LLama.Abstractions; using LLama.Extensions; using LLama.Native; @@ -20,11 +19,6 @@ namespace LLama /// Be careful how you use this! public SafeLlamaModelHandle NativeHandle => _weights; - /// - /// Encoding to use to convert text into bytes for the model - /// - public Encoding Encoding { get; } - /// /// Total number of tokens in vocabulary of this model /// @@ -35,15 +29,24 @@ namespace LLama /// public int ContextSize => NativeHandle.ContextSize; + /// + /// Get the size of this model in bytes + /// + public ulong SizeInBytes => NativeHandle.SizeInBytes; + + /// + /// Get the number of parameters in this model + /// + public ulong ParameterCount => NativeHandle.ParameterCount; + /// /// Dimension of embedding vectors /// public int EmbeddingSize => NativeHandle.EmbeddingSize; - internal LLamaWeights(SafeLlamaModelHandle weights, Encoding encoding) + internal LLamaWeights(SafeLlamaModelHandle weights) { _weights = weights; - Encoding = encoding; } /// @@ -53,13 +56,20 @@ namespace LLama /// public static LLamaWeights LoadFromFile(IModelParams @params) { - using var pin = @params.ToLlamaContextParams(out var lparams); + using var pin = @params.ToLlamaModelParams(out var lparams); var weights = SafeLlamaModelHandle.LoadFromFile(@params.ModelPath, lparams); - if (!string.IsNullOrEmpty(@params.LoraAdapter)) - weights.ApplyLoraFromFile(@params.LoraAdapter, @params.LoraBase, @params.Threads); + foreach (var adapter in @params.LoraAdapters) + { + if (string.IsNullOrEmpty(adapter.Path)) + continue; + if (adapter.Scale <= 0) + continue; + + weights.ApplyLoraFromFile(adapter.Path, adapter.Scale, @params.LoraBase, @params.Threads); + } - return new LLamaWeights(weights, @params.Encoding); + return new LLamaWeights(weights); } /// @@ -73,7 +83,7 @@ namespace LLama /// /// /// - public LLamaContext CreateContext(IModelParams @params) + public LLamaContext CreateContext(IContextParams @params) { return new LLamaContext(this, @params); } diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs new file mode 100644 index 00000000..fd72ddcd --- /dev/null +++ 
b/LLama/Native/LLamaBatchSafeHandle.cs @@ -0,0 +1,106 @@ +using System; + +namespace LLama.Native; + +using llama_token = Int32; + +public sealed class LLamaBatchSafeHandle + : SafeLLamaHandleBase +{ + private readonly int _embd; + public LLamaNativeBatch Batch { get; private set; } + + /// + /// the token ids of the input (used when embd is NULL) + /// + public Span Token + { + get + { + unsafe + { + if (_embd != 0) + return new Span(null, 0); + else + return new Span(Batch.token, Batch.n_tokens); + } + } + } + + /// + /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL) + /// + public Span Embed + { + get + { + unsafe + { + // If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float) + /// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token + + if (_embd != 0) + return new Span(Batch.embd, Batch.n_tokens * _embd); + else + return new Span(null, 0); + } + } + } + + /// + /// the positions of the respective token in the sequence + /// + public Span Pos + { + get + { + unsafe + { + return new Span(Batch.pos, Batch.n_tokens); + } + } + } + + /// + /// the sequence to which the respective token belongs + /// + public Span Sequence_ID + { + get + { + unsafe + { + return new Span(Batch.seq_id, Batch.n_tokens); + } + } + } + + /// + /// if zero, the logits for the respective token will not be output + /// + public Span Logits + { + get + { + unsafe + { + return new Span(Batch.logits, Batch.n_tokens); + } + } + } + + public LLamaBatchSafeHandle(int n_tokens, int embd) + : base((nint)1) + { + _embd = embd; + Batch = NativeApi.llama_batch_init(n_tokens, embd); + } + + protected override bool ReleaseHandle() + { + NativeApi.llama_batch_free(Batch); + Batch = default; + SetHandle(IntPtr.Zero); + return true; + } +} \ No newline at end of file diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs index 200301da..50f30c0a 100644 --- a/LLama/Native/LLamaContextParams.cs +++ b/LLama/Native/LLamaContextParams.cs @@ -19,32 +19,27 @@ namespace LLama.Native /// /// RNG seed, -1 for random /// - public int seed; + public uint seed; /// /// text context /// - public int n_ctx; + public uint n_ctx; /// /// prompt processing batch size /// - public int n_batch; + public uint n_batch; /// - /// number of layers to store in VRAM + /// number of threads to use for generation /// - public int n_gpu_layers; + public uint n_threads; /// - /// the GPU that is used for scratch and small tensors + /// number of threads to use for batch processing /// - public int main_gpu; - - /// - /// how to split layers across multiple GPUs - /// - public nint tensor_split; + public uint n_threads_batch; /// /// ref: https://github.com/ggerganov/llama.cpp/pull/2054 @@ -58,26 +53,6 @@ namespace LLama.Native /// public float rope_freq_scale; - /// - /// called with a progress value between 0 and 1, pass NULL to disable - /// - public IntPtr progress_callback; - - /// - /// context pointer passed to the progress callback - /// - public IntPtr progress_callback_user_data; - - /// - /// if true, reduce VRAM usage at the cost of performance - /// - public bool low_vram - { - readonly get => Convert.ToBoolean(_low_vram); - set => _low_vram = Convert.ToSByte(value); - } - private sbyte _low_vram; - /// /// if true, use experimental mul_mat_q kernels /// @@ -108,36 +83,6 @@ namespace LLama.Native } private sbyte _logits_all; - /// - /// only load the vocabulary, no weights - /// - public bool vocab_only - { - readonly get => 
Convert.ToBoolean(_vocab_only); - set => _vocab_only = Convert.ToSByte(value); - } - private sbyte _vocab_only; - - /// - /// use mmap if possible - /// - public bool use_mmap - { - readonly get => Convert.ToBoolean(_use_mmap); - set => _use_mmap = Convert.ToSByte(value); - } - private sbyte _use_mmap; - - /// - /// force system to keep model in RAM - /// - public bool use_mlock - { - readonly get => Convert.ToBoolean(_use_mlock); - set => _use_mlock = Convert.ToSByte(value); - } - private sbyte _use_mlock; - /// /// embedding mode only /// diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs new file mode 100644 index 00000000..f1f95ced --- /dev/null +++ b/LLama/Native/LLamaModelParams.cs @@ -0,0 +1,67 @@ +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native +{ + /// + /// A C# representation of the llama.cpp `llama_model_params` struct + /// + [StructLayout(LayoutKind.Sequential)] + public unsafe struct LLamaModelParams + { + /// + /// // number of layers to store in VRAM + /// + public int n_gpu_layers; + + /// + /// // the GPU that is used for scratch and small tensors + /// + public int main_gpu; + + /// + /// how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES) + /// + public float* tensor_split; + + /// + /// called with a progress value between 0 and 1, pass NULL to disable + /// + LlamaProgressCallback progress_callback; + + /// + /// context pointer passed to the progress callback + /// + void* progress_callback_user_data; + + /// + /// only load the vocabulary, no weights + /// + public bool vocab_only + { + readonly get => Convert.ToBoolean(_vocab_only); + set => _vocab_only = Convert.ToSByte(value); + } + private sbyte _vocab_only; + + /// + /// use mmap if possible + /// + public bool use_mmap + { + readonly get => Convert.ToBoolean(_use_mmap); + set => _use_mmap = Convert.ToSByte(value); + } + private sbyte _use_mmap; + + /// + /// force system to keep model in RAM + /// + public bool use_mlock + { + readonly get => Convert.ToBoolean(_use_mlock); + set => _use_mlock = Convert.ToSByte(value); + } + private sbyte _use_mlock; + } +} diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs index 128e30aa..4d0b9e7f 100644 --- a/LLama/Native/LLamaModelQuantizeParams.cs +++ b/LLama/Native/LLamaModelQuantizeParams.cs @@ -36,5 +36,15 @@ namespace LLama.Native set => _quantize_output_tensor = Convert.ToSByte(value); } private sbyte _quantize_output_tensor; + + /// + /// only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored + /// + public bool only_copy + { + get => Convert.ToBoolean(_only_copy); + set => _only_copy = Convert.ToSByte(value); + } + private sbyte _only_copy; } } diff --git a/LLama/Native/LLamaNativeBatch.cs b/LLama/Native/LLamaNativeBatch.cs new file mode 100644 index 00000000..576f8b27 --- /dev/null +++ b/LLama/Native/LLamaNativeBatch.cs @@ -0,0 +1,45 @@ +using System; +using System.Runtime.InteropServices; + +namespace LLama.Native; + +using llama_token = Int32; + +/// +/// Input data for llama_decode +/// A llama_batch object can contain input about one or many sequences +/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens +/// +[StructLayout(LayoutKind.Sequential)] +public readonly unsafe struct LLamaNativeBatch +{ + /// + /// The number of items pointed at by pos, seq_id and logits. 
+ /// + public readonly int n_tokens; + + /// + /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created + /// + public readonly llama_token* token; + + /// + /// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created + /// + public readonly float* embd; + + /// + /// the positions of the respective token in the sequence + /// + public readonly LLamaPos* pos; + + /// + /// the sequence to which the respective token belongs + /// + public readonly LLamaSeqId* seq_id; + + /// + /// if zero, the logits for the respective token will not be output + /// + public readonly byte* logits; +} \ No newline at end of file diff --git a/LLama/Native/LLamaPos.cs b/LLama/Native/LLamaPos.cs new file mode 100644 index 00000000..18dc8294 --- /dev/null +++ b/LLama/Native/LLamaPos.cs @@ -0,0 +1,15 @@ +namespace LLama.Native; + +public record struct LLamaPos +{ + public int Value; + + public LLamaPos(int value) + { + Value = value; + } + + public static explicit operator int(LLamaPos pos) => pos.Value; + + public static implicit operator LLamaPos(int value) => new(value); +} \ No newline at end of file diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs new file mode 100644 index 00000000..d148fe4d --- /dev/null +++ b/LLama/Native/LLamaSeqId.cs @@ -0,0 +1,15 @@ +namespace LLama.Native; + +public record struct LLamaSeqId +{ + public int Value; + + public LLamaSeqId(int value) + { + Value = value; + } + + public static explicit operator int(LLamaSeqId pos) => pos.Value; + + public static explicit operator LLamaSeqId(int value) => new(value); +} \ No newline at end of file diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 300c6495..15380a37 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -2,7 +2,6 @@ using System.Buffers; using System.Runtime.InteropServices; using System.Text; -using LLama.Common; using LLama.Exceptions; #pragma warning disable IDE1006 // Naming Styles @@ -110,6 +109,13 @@ namespace LLama.Native [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)] public static extern bool llama_empty_call(); + /// + /// Create a LLamaModelParams with default values + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern LLamaModelParams llama_model_default_params(); + /// /// Create a LLamaContextParams with default values /// @@ -138,18 +144,6 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern bool llama_mlock_supported(); - /// - /// Export a static computation graph for context of 511 and batch size of 1 - /// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these - /// parameters here to keep things simple - /// IMPORTANT: do not use for anything else other than debugging and testing! - /// - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_eval_export(SafeLLamaContextHandle ctx, string fname); - /// /// Various functions for loading a ggml llama model. /// Allocate (almost) all memory needed for the model. 
@@ -159,7 +153,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams @params); + public static extern IntPtr llama_load_model_from_file(string path_model, LLamaModelParams @params); /// /// Create a new llama_context with the given model. @@ -192,7 +186,7 @@ namespace LLama.Native /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern void llama_free_model(IntPtr model); - + /// /// Apply a LoRA adapter to a loaded model /// path_base_model is the path to a higher quality model to use as a base for @@ -202,19 +196,12 @@ namespace LLama.Native /// /// /// + /// /// /// /// Returns 0 on success [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads); - - /// - /// Returns the number of tokens in the KV cache - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_get_kv_cache_token_count(SafeLLamaContextHandle ctx); + public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, float scale, string? path_base_model, int n_threads); /// /// Sets the current rng seed. @@ -222,7 +209,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, int seed); + public static extern void llama_set_rng_seed(SafeLLamaContextHandle ctx, uint seed); /// /// Returns the maximum size in bytes of the state (rng, logits, embedding @@ -243,21 +230,6 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest); - /// - /// Copies the state to the specified destination address. - /// Destination needs to have allocated enough memory (see llama_get_state_size) - /// - /// - /// - /// the number of bytes copied - public static ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte[] dest) - { - fixed (byte* dstPtr = &dest[0]) - { - return llama_copy_state_data(ctx, dstPtr); - } - } - /// /// Set the state reading from the specified address /// @@ -267,20 +239,6 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src); - /// - /// Set the state reading from the specified address - /// - /// - /// - /// the number of bytes read - public static ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte[] src) - { - fixed (byte* srcPtr = &src[0]) - { - return llama_set_state_data(ctx, srcPtr); - } - } - /// /// Load session file /// @@ -313,24 +271,9 @@ namespace LLama.Native /// /// /// - /// /// Returns 0 on success [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int n_tokens, int n_past, int n_threads); - - /// - /// Run the llama inference to obtain the logits and probabilities for the next token. 
- /// tokens + n_tokens is the provided batch of new tokens to process - /// n_past is the number of tokens to use from previous eval calls - /// - /// - /// - /// - /// - /// - /// Returns 0 on success - [DllImport(libraryName, EntryPoint = "llama_eval", CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_eval_with_pointer(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past, int n_threads); + public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past); /// /// Convert the provided text into tokens. @@ -341,10 +284,11 @@ namespace LLama.Native /// /// /// + /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space. /// Returns the number of tokens on success, no more than n_max_tokens. /// Returns a negative number on failure - the number of tokens that would have been returned /// - public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos) + public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special) { // Calculate number of bytes in text and borrow an array that large (+1 for nul byte) var byteCount = encoding.GetByteCount(text); @@ -364,7 +308,7 @@ namespace LLama.Native // Do the actual tokenization fixed (byte* arrayPtr = array) fixed (llama_token* tokensPtr = tokens) - return llama_tokenize_native(ctx, arrayPtr, tokensPtr, n_max_tokens, add_bos); + return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special); } finally { @@ -372,28 +316,6 @@ namespace LLama.Native } } - /// - /// Convert the provided text into tokens. - /// - /// - /// - /// - /// - /// - /// Returns the number of tokens on success, no more than n_max_tokens. - /// Returns a negative number on failure - the number of tokens that would have been returned - /// - [DllImport(libraryName, EntryPoint = "llama_tokenize", CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_tokenize_native(SafeLLamaContextHandle ctx, byte* text, llama_token* tokens, int n_max_tokens, bool add_bos); - - /// - /// Get the number of tokens in the model vocabulary for this context - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_n_vocab(SafeLLamaContextHandle ctx); - /// /// Get the size of the context window for the model for this context /// @@ -402,14 +324,6 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern int llama_n_ctx(SafeLLamaContextHandle ctx); - /// - /// Get the dimension of embedding vectors from the model for this context - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_n_embd(SafeLLamaContextHandle ctx); - /// /// Token logits obtained from the last call to llama_eval() /// The logits for the last token are stored in the last row @@ -423,22 +337,21 @@ namespace LLama.Native public static extern float* llama_get_logits(SafeLLamaContextHandle ctx); /// - /// Get the embeddings for the input - /// shape: [n_embd] (1-dimensional) + /// Logits for the ith token. 
Equivalent to: llama_get_logits(ctx) + i*n_vocab /// /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx); + public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx); /// - /// Token Id -> String. Uses the vocabulary in the provided context + /// Get the embeddings for the input + /// shape: [n_embd] (1-dimensional) /// /// - /// - /// Pointer to a string. + /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern IntPtr llama_token_to_str(SafeLLamaContextHandle ctx, llama_token token); + public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx); /// /// Get the "Beginning of sentence" token @@ -488,7 +401,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_n_vocab(SafeLlamaModelHandle model); + public static extern int llama_n_vocab(SafeLlamaModelHandle model); /// /// Get the size of the context window for the model @@ -496,7 +409,7 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_n_ctx(SafeLlamaModelHandle model); + public static extern int llama_n_ctx_train(SafeLlamaModelHandle model); /// /// Get the dimension of embedding vectors from this model @@ -504,7 +417,23 @@ namespace LLama.Native /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_n_embd(SafeLlamaModelHandle model); + public static extern int llama_n_embd(SafeLlamaModelHandle model); + + /// + /// Get the size of the model in bytes + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern ulong llama_model_size(SafeLlamaModelHandle model); + + /// + /// Get the number of parameters in this model + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern ulong llama_model_n_params(SafeLlamaModelHandle model); /// /// Convert a single token into text @@ -515,21 +444,23 @@ namespace LLama.Native /// size of the buffer /// The length writte, or if the buffer is too small a negative that indicates the length required [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_token_to_piece_with_model(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); + public static extern int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); /// /// Convert text into tokens /// /// /// + /// /// /// /// + /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. Does not insert a leading space. /// Returns the number of tokens on success, no more than n_max_tokens. 
/// Returns a negative number on failure - the number of tokens that would have been returned /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos); + public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special); /// /// Register a callback to receive llama log messages @@ -537,5 +468,98 @@ namespace LLama.Native /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern void llama_log_set(LLamaLogCallback logCallback); - } + + /// + /// Remove all tokens data of cells in [c0, c1) + /// + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_kv_cache_tokens_rm(SafeLLamaContextHandle ctx, int c0, int c1); + + /// + /// Removes all tokens that belong to the specified sequence and have positions in [p0, p1) + /// + /// + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1); + + /// + /// Copy all tokens that belong to the specified sequence to another sequence + /// Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + /// + /// + /// + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_kv_cache_seq_cp(SafeLLamaContextHandle ctx, LLamaSeqId src, LLamaSeqId dest, LLamaPos p0, LLamaPos p1); + + /// + /// Removes all tokens that do not belong to the specified sequence + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_kv_cache_seq_keep(SafeLLamaContextHandle ctx, LLamaSeqId seq); + + /// + /// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) + /// If the KV cache is RoPEd, the KV data is updated accordingly + /// + /// + /// + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta); + + /// + /// Allocates a batch of tokens on the heap + /// The batch has to be freed with llama_batch_free() + /// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) + /// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token + /// The rest of the llama_batch members are allocated with size n_tokens + /// All members are left uninitialized + /// + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd); + + /// + /// Frees a batch of tokens allocated with llama_batch_init() + /// + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern void llama_batch_free(LLamaNativeBatch batch); + + /// + /// + /// + /// + /// Positive return values does not mean a fatal error, but rather a warning:
+ /// - 0: success
+ /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
+ /// - < 0: error
+ ///
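The return-code convention listed above means only negative values are fatal; a return of 1 is recoverable. A minimal sketch of how a caller might react, assuming `ctx` is a valid SafeLLamaContextHandle and `batch` is an LLamaNativeBatch allocated with llama_batch_init and populated elsewhere (batch population is not shown in this change set); it uses the llama_decode P/Invoke declared just below:

// Sketch only: interprets the llama_decode return codes documented above.
// `ctx` and `batch` are assumed to exist; building the batch is out of scope here.
int status = NativeApi.llama_decode(ctx, batch);
if (status < 0)
    throw new RuntimeError($"llama_decode failed with code {status}");
if (status == 1)
{
    // Not fatal: no KV slot was available for this batch.
    // A caller could retry with a smaller batch or a larger context.
}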
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern int llama_decode(SafeLLamaContextHandle ctx, LLamaNativeBatch batch); + + /// + /// Set the number of threads used for decoding + /// + /// + /// n_threads is the number of threads used for generation (single token) + /// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) + /// + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern int llama_set_n_threads(SafeLLamaContextHandle ctx, uint n_threads, uint n_threads_batch); + } } diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 26fd011b..88572254 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -1,5 +1,6 @@ using System; using System.Buffers; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; @@ -21,26 +22,13 @@ namespace LLama.Native /// /// Total number of tokens in the context /// - public int ContextSize => ThrowIfDisposed().ContextSize; + public int ContextSize => NativeApi.llama_n_ctx(this); /// /// Dimension of embedding vectors /// public int EmbeddingSize => ThrowIfDisposed().EmbeddingSize; - /// - /// Get the number of tokens in the KV Cache for this context - /// - public int KVCacheTokenCount - { - get - { - if (IsClosed) - throw new ObjectDisposedException("Cannot use this `SafeLLamaContextHandle` - it has been disposed"); - return NativeApi.llama_get_kv_cache_token_count(this); - } - } - /// /// Get the model which this context is using /// @@ -64,17 +52,20 @@ namespace LLama.Native _model.DangerousAddRef(ref success); if (!success) throw new RuntimeError("Failed to increment model refcount"); + + } /// protected override bool ReleaseHandle() { + NativeApi.llama_free(DangerousGetHandle()); + SetHandle(IntPtr.Zero); + // Decrement refcount on model _model?.DangerousRelease(); _model = null!; - NativeApi.llama_free(handle); - SetHandle(IntPtr.Zero); return true; } @@ -103,46 +94,38 @@ namespace LLama.Native return new(ctx_ptr, model); } + #endregion /// - /// Create a new llama context with a clone of the current llama context state + /// Token logits obtained from the last call to llama_eval() + /// The logits for the last token are stored in the last row + /// Can be mutated in order to change the probabilities of the next token.
+ /// Rows: n_tokens
+ /// Cols: n_vocab ///
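To illustrate the layout described in the comment above: the managed span has one entry per vocabulary token and refers to the last evaluated position. A rough sketch, assuming `ctx` is a live SafeLLamaContextHandle on which Eval has already succeeded:

// Sketch: greedily pick the most likely next token from the logits
// of the last evaluated position. Assumes a prior successful Eval call.
Span<float> logits = ctx.GetLogits();   // length == VocabCount
int best = 0;
for (int i = 1; i < logits.Length; i++)
{
    if (logits[i] > logits[best])
        best = i;
}
// `best` is the id of the highest-logit token (purely greedy, no sampling).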
- /// /// - public SafeLLamaContextHandle Clone(LLamaContextParams lparams) + public Span GetLogits() { - // Allocate space to read the state of the current context - var stateSize = GetStateSize(); - var stateMemory = Marshal.AllocHGlobal((nint)stateSize); - try - { - // Copy state from this context into memory - GetState(stateMemory, stateSize); - - // Create a new context - var newCtx = Create(ModelHandle, lparams); - - // Copy state into new context - newCtx.SetState(stateMemory); + var model = ThrowIfDisposed(); - return newCtx; - } - finally + unsafe { - Marshal.FreeHGlobal(stateMemory); + var logits = NativeApi.llama_get_logits(this); + return new Span(logits, model.VocabCount); } } - #endregion + #region tokens /// /// Convert the given text into tokens /// /// The text to tokenize /// Whether the "BOS" token should be added /// Encoding to use for the text + /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. /// /// - public int[] Tokenize(string text, bool add_bos, Encoding encoding) + public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) { ThrowIfDisposed(); @@ -158,7 +141,7 @@ namespace LLama.Native try { // Do the actual conversion - var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos); + var n = NativeApi.llama_tokenize(this, text, encoding, temporaryArray, count, add_bos, special); if (n < 0) { throw new RuntimeError("Error happened during tokenization. It's possibly caused by wrong encoding. Please try to " + @@ -177,25 +160,6 @@ namespace LLama.Native } } - /// - /// Token logits obtained from the last call to llama_eval() - /// The logits for the last token are stored in the last row - /// Can be mutated in order to change the probabilities of the next token.
- /// Rows: n_tokens
- /// Cols: n_vocab - ///
- /// - public Span GetLogits() - { - var model = ThrowIfDisposed(); - - unsafe - { - var logits = NativeApi.llama_get_logits(this); - return new Span(logits, model.VocabCount); - } - } - /// /// Convert a token into a string /// @@ -228,25 +192,31 @@ namespace LLama.Native { return ThrowIfDisposed().TokenToSpan(token, dest); } + #endregion /// /// Run the llama inference to obtain the logits and probabilities for the next token. /// /// The provided batch of new tokens to process /// the number of tokens to use from previous eval calls - /// /// Returns true on success - public bool Eval(ReadOnlySpan tokens, int n_past, int n_threads) + public bool Eval(ReadOnlySpan tokens, int n_past) { unsafe { fixed (int* pinned = tokens) { - return NativeApi.llama_eval_with_pointer(this, pinned, tokens.Length, n_past, n_threads) == 0; + var ret = NativeApi.llama_eval(this, pinned, tokens.Length, n_past); + return ret == 0; } } } + public int Decode(LLamaBatchSafeHandle batch) + { + return NativeApi.llama_decode(this, batch.Batch); + } + #region state /// /// Get the size of the state, when saved as bytes diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index 615889d5..adf6bd54 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -29,18 +29,30 @@ namespace LLama.Native /// public int EmbeddingSize { get; } + /// + /// Get the size of this model in bytes + /// + public ulong SizeInBytes { get; } + + /// + /// Get the number of parameters in this model + /// + public ulong ParameterCount { get; } + internal SafeLlamaModelHandle(IntPtr handle) : base(handle) { - VocabCount = NativeApi.llama_model_n_vocab(this); - ContextSize = NativeApi.llama_model_n_ctx(this); - EmbeddingSize = NativeApi.llama_model_n_embd(this); + VocabCount = NativeApi.llama_n_vocab(this); + ContextSize = NativeApi.llama_n_ctx_train(this); + EmbeddingSize = NativeApi.llama_n_embd(this); + SizeInBytes = NativeApi.llama_model_size(this); + ParameterCount = NativeApi.llama_model_n_params(this); } /// protected override bool ReleaseHandle() { - NativeApi.llama_free_model(handle); + NativeApi.llama_free_model(DangerousGetHandle()); SetHandle(IntPtr.Zero); return true; } @@ -52,7 +64,7 @@ namespace LLama.Native /// /// /// - public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaContextParams lparams) + public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams) { var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams); if (model_ptr == IntPtr.Zero) @@ -62,21 +74,24 @@ namespace LLama.Native } #region LoRA + /// /// Apply a LoRA adapter to a loaded model /// /// + /// /// A path to a higher quality model to use as a base for the layers modified by the /// adapter. Can be NULL to use the current loaded model. /// /// - public void ApplyLoraFromFile(string lora, string? modelBase = null, int threads = -1) + public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null, uint? threads = null) { var err = NativeApi.llama_model_apply_lora_from_file( this, lora, + scale, string.IsNullOrEmpty(modelBase) ? null : modelBase, - threads + (int?)threads ?? 
-1 ); if (err != 0) @@ -97,7 +112,7 @@ namespace LLama.Native { fixed (byte* destPtr = dest) { - var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, destPtr, dest.Length); + var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length); return Math.Abs(length); } } @@ -113,7 +128,7 @@ namespace LLama.Native { unsafe { - var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0); + var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0); if (length == 0) return ""; @@ -121,7 +136,7 @@ namespace LLama.Native fixed (byte* bytePtr = bytes) { - var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length); + var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length); Debug.Assert(written == bytes.Length); return encoding.GetString(bytePtr, bytes.Length); @@ -139,7 +154,7 @@ namespace LLama.Native { unsafe { - var length = NativeApi.llama_token_to_piece_with_model(this, llama_token, null, 0); + var length = NativeApi.llama_token_to_piece(this, llama_token, null, 0); if (length == 0) return; @@ -147,7 +162,7 @@ namespace LLama.Native fixed (byte* bytePtr = bytes) { // Decode into bytes - var written = NativeApi.llama_token_to_piece_with_model(this, llama_token, bytePtr, bytes.Length); + var written = NativeApi.llama_token_to_piece(this, llama_token, bytePtr, bytes.Length); Debug.Assert(written == bytes.Length); // Decode into chars @@ -256,8 +271,9 @@ namespace LLama.Native /// /// /// + /// Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. /// - public int[] Tokenize(string text, bool add_bos, Encoding encoding) + public int[] Tokenize(string text, bool add_bos, bool special, Encoding encoding) { // Convert string to bytes, adding one extra byte to the end (null terminator) var bytesCount = encoding.GetByteCount(text); @@ -276,13 +292,13 @@ namespace LLama.Native fixed (byte* bytesPtr = &bytes[0]) { // Tokenize once with no output, to get the token count. 
Output will be negative (indicating that there was insufficient space) - var count = -NativeApi.llama_tokenize_with_model(this, bytesPtr, (int*)IntPtr.Zero, 0, add_bos); + var count = -NativeApi.llama_tokenize(this, bytesPtr, bytesCount, (int*)IntPtr.Zero, 0, add_bos, special); // Tokenize again, this time outputting into an array of exactly the right size var tokens = new int[count]; fixed (int* tokensPtr = &tokens[0]) { - NativeApi.llama_tokenize_with_model(this, bytesPtr, tokensPtr, count, add_bos); + NativeApi.llama_tokenize(this, bytesPtr, bytesCount, tokensPtr, count, add_bos, special); return tokens; } } diff --git a/LLama/Utils.cs b/LLama/Utils.cs deleted file mode 100644 index f3584c81..00000000 --- a/LLama/Utils.cs +++ /dev/null @@ -1,108 +0,0 @@ -using LLama.Abstractions; -using LLama.Native; -using System; -using System.Collections.Generic; -using System.Runtime.InteropServices; -using System.Text; -using LLama.Extensions; - -namespace LLama -{ - using llama_token = Int32; - - /// - /// Assorted llama utilities - /// - public static class Utils - { - [Obsolete("Use LLamaWeights.LoadFromFile and LLamaWeights.CreateContext instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static SafeLLamaContextHandle InitLLamaContextFromModelParams(IModelParams @params) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - using var weights = LLamaWeights.LoadFromFile(@params); - - using (@params.ToLlamaContextParams(out var lparams)) - return SafeLLamaContextHandle.Create(weights.NativeHandle, lparams); - } - - [Obsolete("Use SafeLLamaContextHandle Tokenize method instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static IEnumerable Tokenize(SafeLLamaContextHandle ctx, string text, bool add_bos, Encoding encoding) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - return ctx.Tokenize(text, add_bos, encoding); - } - - [Obsolete("Use SafeLLamaContextHandle GetLogits method instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static Span GetLogits(SafeLLamaContextHandle ctx, int length) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - if (length != ctx.VocabCount) - throw new ArgumentException("length must be the VocabSize"); - - return ctx.GetLogits(); - } - - [Obsolete("Use SafeLLamaContextHandle Eval method instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static int Eval(SafeLLamaContextHandle ctx, llama_token[] tokens, int startIndex, int n_tokens, int n_past, int n_threads) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - var slice = tokens.AsSpan().Slice(startIndex, n_tokens); - return ctx.Eval(slice, n_past, n_threads) ? 
0 : 1; - } - - [Obsolete("Use SafeLLamaContextHandle TokenToString method instead")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static string TokenToString(llama_token token, SafeLLamaContextHandle ctx, Encoding encoding) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { - return ctx.TokenToString(token, encoding); - } - - [Obsolete("No longer used internally by LlamaSharp")] - #pragma warning disable CS1591 // Missing XML comment for publicly visible type or member - public static string PtrToString(IntPtr ptr, Encoding encoding) - #pragma warning restore CS1591 // Missing XML comment for publicly visible type or member - { -#if NET6_0_OR_GREATER - // ReSharper disable once PossibleUnintendedReferenceComparison - if(encoding == Encoding.UTF8) - { - return Marshal.PtrToStringUTF8(ptr)!; - } - // ReSharper disable once PossibleUnintendedReferenceComparison - else if(encoding == Encoding.Unicode) - { - return Marshal.PtrToStringUni(ptr)!; - } - else - { - return Marshal.PtrToStringAuto(ptr)!; - } -#else - unsafe - { - byte* tp = (byte*)ptr.ToPointer(); - List bytes = new(); - while (true) - { - byte c = *tp++; - if (c == '\0') - { - break; - } - else - { - bytes.Add(c); - } - } - return encoding.GetString(bytes.ToArray()); - } -#endif - } - - } -} diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal index 7b5c21d9..99b9fd7a 100644 --- a/LLama/runtimes/ggml-metal.metal +++ b/LLama/runtimes/ggml-metal.metal @@ -13,8 +13,8 @@ typedef struct { #define QK4_1 32 typedef struct { - half d; // delta - half m; // min + half d; // delta + half m; // min uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; @@ -24,12 +24,59 @@ typedef struct { int8_t qs[QK8_0]; // quants } block_q8_0; +// general-purpose kernel for addition of two tensors +// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +// cons: not very efficient kernel void kernel_add( - device const float4 * src0, - device const float4 * src1, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig]; + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & nb00, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant int64_t & nb0, + constant int64_t & nb1, + constant int64_t & nb2, + constant int64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + 
((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0]; + + src0_ptr += ntg.x*nb00; + src1_ptr += ntg.x*nb10; + dst_ptr += ntg.x*nb0; + } } // assumption: src1 is a row @@ -38,7 +85,7 @@ kernel void kernel_add_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb, + constant int64_t & nb [[buffer(27)]], uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig % nb]; } @@ -63,18 +110,18 @@ kernel void kernel_mul_row( } kernel void kernel_scale( - device const float * src0, - device float * dst, + device const float4 * src0, + device float4 * dst, constant float & scale, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * scale; } kernel void kernel_silu( - device const float * src0, - device float * dst, + device const float4 * src0, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { - float x = src0[tpig]; + device const float4 & x = src0[tpig]; dst[tpig] = x / (1.0f + exp(-x)); } @@ -85,14 +132,21 @@ kernel void kernel_relu( dst[tpig] = max(0.0f, src0[tpig]); } +kernel void kernel_sqr( + device const float * src0, + device float * dst, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * src0[tpig]; +} + constant float GELU_COEF_A = 0.044715f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; kernel void kernel_gelu( - device const float * src0, - device float * dst, + device const float4 * src0, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { - float x = src0[tpig]; + device const float4 & x = src0[tpig]; // BEWARE !!! // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! @@ -107,7 +161,6 @@ kernel void kernel_soft_max( constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, - threadgroup float * buf [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -119,61 +172,67 @@ kernel void kernel_soft_max( device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; // parallel max - buf[tpitg[0]] = -INFINITY; - for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]); - } - - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg[0]/2; i > 0; i /= 2) { - if (tpitg[0] < i) { - buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]); - } - threadgroup_barrier(mem_flags::mem_threadgroup); + float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY; + for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) { + lmax = MAX(lmax, psrc0[i00]); } - - //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of - // the loop, and when that is done, buf[0] has the correct (synchronized) value - //if (tpitg[0] == 0) { - // buf[0] = buf[0]; - //} - - //threadgroup_barrier(mem_flags::mem_threadgroup); - - const float max = buf[0]; + const float max = simd_max(lmax); // parallel sum - buf[tpitg[0]] = 0.0f; + float lsum = 0.0f; for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { const float exp_psrc0 = exp(psrc0[i00] - max); - buf[tpitg[0]] += exp_psrc0; + lsum += exp_psrc0; // Remember the result of exp here. exp is expensive, so we really do not // whish to compute it twice. 
pdst[i00] = exp_psrc0; } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg[0]/2; i > 0; i /= 2) { - if (tpitg[0] < i) { - buf[tpitg[0]] += buf[tpitg[0] + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + const float sum = simd_sum(lsum); + + for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { + pdst[i00] /= sum; } +} - // broadcast - not needed, see above - //// broadcast - //if (tpitg[0] == 0) { - // buf[0] = buf[0]; - //} +kernel void kernel_soft_max_4( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig[2]; + const int64_t i02 = tgpig[1]; + const int64_t i01 = tgpig[0]; - //threadgroup_barrier(mem_flags::mem_threadgroup); + device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - const float sum = buf[0]; + // parallel max + float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY; + for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) { + lmax4 = fmax(lmax4, psrc4[i00]); + } + float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); - for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) { - pdst[i00] /= sum; + const float max = simd_max(lmax); + + // parallel sum + float4 lsum4 = 0.0f; + for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { + const float4 exp_psrc4 = exp(psrc4[i00] - max); + lsum4 += exp_psrc4; + pdst4[i00] = exp_psrc4; + } + float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3]; + + const float sum = simd_sum(lsum); + + for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) { + pdst4[i00] /= sum; } } @@ -192,6 +251,33 @@ kernel void kernel_diag_mask_inf( dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY; } else { dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00]; + } +} + +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01; + const int64_t i01 = i4/(ne00); i4 -= i01*ne00; + const int64_t i00 = i4; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } } } @@ -259,10 +345,11 @@ kernel void kernel_rms_norm( uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - device const float * x_scalar = (device const float *) x; - float4 sumf=0; - float all_sum=0; + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + device const float * x_scalar = (device const float *) x; + + float4 sumf = 0; + float all_sum = 0; // parallel sum for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { @@ -275,6 +362,7 @@ kernel void kernel_rms_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast, simd group number is ntg / 32 for (uint i = ntg / 32 / 2; i > 0; i 
/= 2) { if (tpitg < i) { @@ -282,7 +370,9 @@ kernel void kernel_rms_norm( } } if (tpitg == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } sum[0] /= ne00; } @@ -297,7 +387,9 @@ kernel void kernel_rms_norm( y[i00] = x[i00] * scale; } if (tpitg == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } } } @@ -337,8 +429,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre } // putting them in the kernel cause a significant performance penalty -#define N_DST 4 // each SIMD group works on 4 rows -#define N_SIMDGROUP 2 // number of SIMD groups in a thread group +#define N_DST 4 // each SIMD group works on 4 rows +#define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 //Note: This is a template, but strictly speaking it only applies to // quantizations where the block size is 32. It also does not @@ -349,18 +441,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; + const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; + const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + device const block_q_type * x = (device const block_q_type *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; - float yl[16]; // src1 vector cache - float sumf[nr]={0.f}; - const int ix = tiisg/2; - const int il = 8*(tiisg%2); + float yl[16]; // src1 vector cache + float sumf[nr] = {0.f}; + + const int ix = (tiisg/2); + const int il = (tiisg%2)*8; device const float * yb = y + ix * QK4_0 + il; @@ -371,6 +468,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device sumy += yb[i] + yb[i+1]; yl[i+0] = yb[i+ 0]; yl[i+1] = yb[i+ 1]/256.f; + sumy += yb[i+16] + yb[i+17]; yl[i+8] = yb[i+16]/16.f; yl[i+9] = yb[i+17]/4096.f; @@ -386,12 +484,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; } } } -kernel void kernel_mul_mat_q4_0_f32( +kernel void kernel_mul_mv_q4_0_f32( device const void * src0, device const float * src1, device float * dst, @@ -404,12 +502,12 @@ kernel void kernel_mul_mat_q4_0_f32( constant int64_t & ne1[[buffer(16)]], constant uint & gqa[[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } -kernel void kernel_mul_mat_q4_1_f32( +kernel void kernel_mul_mv_q4_1_f32( device const void * src0, device const float * src1, device float * dst, @@ -429,7 +527,7 @@ kernel void kernel_mul_mat_q4_1_f32( #define NB_Q8_0 8 -kernel void kernel_mul_mat_q8_0_f32( +kernel void kernel_mul_mv_q8_0_f32( device const void * 
src0, device const float * src1, device float * dst, @@ -491,7 +589,9 @@ kernel void kernel_mul_mat_q8_0_f32( } } -kernel void kernel_mul_mat_f16_f32_1row( +#define N_F32_F32 4 + +kernel void kernel_mul_mv_f32_f32( device const char * src0, device const char * src1, device float * dst, @@ -512,6 +612,77 @@ kernel void kernel_mul_mat_f16_f32_1row( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { + const int64_t r0 = tgpig.x; + const int64_t rb = tgpig.y*N_F32_F32; + const int64_t im = tgpig.z; + + device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + + if (ne00 < 128) { + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00; i += 32) { + sumf += (float) x[i] * (float) y[i]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } else { + device const float4 * x4 = (device const float4 *)x; + for (int row = 0; row < N_F32_F32; ++row) { + int r1 = rb + row; + if (r1 >= ne11) { + break; + } + + device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); + device const float4 * y4 = (device const float4 *) y; + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i]; + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } + } +} + +kernel void kernel_mul_mv_f16_f32_1row( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; @@ -545,7 +716,7 @@ kernel void kernel_mul_mat_f16_f32_1row( #define N_F16_F32 4 -kernel void kernel_mul_mat_f16_f32( +kernel void kernel_mul_mv_f16_f32( device const char * src0, device const char * src1, device float * dst, @@ -616,6 +787,49 @@ kernel void kernel_mul_mat_f16_f32( } } +// Assumes row size (ne00) is a multiple of 4 +kernel void kernel_mul_mv_f16_f32_l4( + device const char * src0, + device const char * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]]) { + + const int nrows = ne11; + const int64_t r0 = tgpig.x; + const int64_t im = tgpig.z; + + device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + + for (int r1 = 0; r1 < nrows; ++r1) { + device const float4 * y4 = (device const float4 *) 
(src1 + r1*nb11 + im*nb12); + + float sumf = 0; + for (int i = tiisg; i < ne00/4; i += 32) { + for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k]; + } + + float all_sum = simd_sum(sumf); + if (tiisg == 0) { + dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + } + } +} + kernel void kernel_alibi_f32( device const float * src0, device float * dst, @@ -635,7 +849,9 @@ kernel void kernel_alibi_f32( constant uint64_t & nb1, constant uint64_t & nb2, constant uint64_t & nb3, - constant float & m0, + constant float & m0, + constant float & m1, + constant int & n_heads_log2_floor, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -651,37 +867,73 @@ kernel void kernel_alibi_f32( const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - float m_k = pow(m0, i2 + 1); + float m_k; + if (i2 < n_heads_log2_floor) { + m_k = pow(m0, i2 + 1); + } else { + m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1); + } for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); } } +typedef void (rope_t)( + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant float & freq_base, + constant float & freq_scale, + uint tiitg[[thread_index_in_threadgroup]], + uint3 tptg[[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]); + +template kernel void kernel_rope( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & mode, - constant float & freq_base, - constant float & freq_scale, + device const void * src0, + device const int32_t * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + constant int & n_past, + constant int & n_dims, + constant int & mode, + constant float & freq_base, + constant float & freq_scale, uint tiitg[[thread_index_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) 
{ @@ -691,7 +943,9 @@ kernel void kernel_rope( const bool is_neox = mode & 2; - const int64_t p = ((mode & 1) == 0 ? n_past + i2 : i2); + device const int32_t * pos = src1; + + const int64_t p = pos[i2]; const float theta_0 = freq_scale * (float)p; const float inv_ndims = -1.f/n_dims; @@ -703,11 +957,11 @@ kernel void kernel_rope( const float cos_theta = cos(theta); const float sin_theta = sin(theta); - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - const float x0 = src[0]; - const float x1 = src[1]; + const T x0 = src[0]; + const T x1 = src[1]; dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; @@ -722,8 +976,8 @@ kernel void kernel_rope( const int64_t i0 = ib*n_dims + ic/2; - device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device float * dst_data = (device float *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); const float x0 = src[0]; const float x1 = src[n_dims/2]; @@ -735,6 +989,9 @@ kernel void kernel_rope( } } +template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope; +template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, @@ -860,6 +1117,62 @@ kernel void kernel_cpy_f32_f32( } } +kernel void kernel_concat( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant uint64_t & nb0, + constant uint64_t & nb1, + constant uint64_t & nb2, + constant uint64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + if (i02 < ne02) { + ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0]; + src0_ptr += ntg.x*nb00; + } else { + ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0]; + src1_ptr += ntg.x*nb10; + } + dst_ptr += ntg.x*nb0; + } +} + 
//============================================ k-quants ====================================================== #ifndef QK_K @@ -952,7 +1265,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) { //====================================== dot products ========================= -kernel void kernel_mul_mat_q2_K_f32( +kernel void kernel_mul_mv_q2_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1096,7 +1409,7 @@ kernel void kernel_mul_mat_q2_K_f32( } #if QK_K == 256 -kernel void kernel_mul_mat_q3_K_f32( +kernel void kernel_mul_mv_q3_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1125,8 +1438,8 @@ kernel void kernel_mul_mat_q3_K_f32( float yl[32]; - const uint16_t kmask1 = 0x3030; - const uint16_t kmask2 = 0x0f0f; + //const uint16_t kmask1 = 0x3030; + //const uint16_t kmask2 = 0x0f0f; const int tid = tiisg/4; const int ix = tiisg%4; @@ -1246,10 +1559,9 @@ kernel void kernel_mul_mat_q3_K_f32( dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; } } - } #else -kernel void kernel_mul_mat_q3_K_f32( +kernel void kernel_mul_mv_q3_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1320,18 +1632,18 @@ kernel void kernel_mul_mat_q3_K_f32( #endif #if QK_K == 256 -kernel void kernel_mul_mat_q4_K_f32( +kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01[[buffer(4)]], - constant int64_t & ne02[[buffer(5)]], - constant int64_t & ne10[[buffer(9)]], - constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne01 [[buffer(4)]], + constant int64_t & ne02 [[buffer(5)]], + constant int64_t & ne10 [[buffer(9)]], + constant int64_t & ne12 [[buffer(11)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & gqa [[buffer(17)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1426,7 +1738,7 @@ kernel void kernel_mul_mat_q4_K_f32( } } #else -kernel void kernel_mul_mat_q4_K_f32( +kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1515,7 +1827,7 @@ kernel void kernel_mul_mat_q4_K_f32( } #endif -kernel void kernel_mul_mat_q5_K_f32( +kernel void kernel_mul_mv_q5_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1688,7 +2000,7 @@ kernel void kernel_mul_mat_q5_K_f32( } -kernel void kernel_mul_mat_q6_K_f32( +kernel void kernel_mul_mv_q6_K_f32( device const void * src0, device const float * src1, device float * dst, @@ -1790,6 +2102,15 @@ kernel void kernel_mul_mat_q6_K_f32( //============================= templates and their specializations ============================= +// NOTE: this is not dequantizing - we are simply fitting the template +template +void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { + float4x4 temp = *(((device float4x4 *)src)); + for (int i = 0; i < 16; i++){ + reg[i/4][i%4] = temp[i/4][i%4]; + } +} + template void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) { half4x4 temp = *(((device half4x4 *)src)); @@ -1801,28 +2122,30 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) template void dequantize_q4_0(device const block_q4_0 *xb, 
short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 1); - const half d = il ? (xb->d / 16.h) : xb->d; - const half m = il ? ( -8.h * 16.h) : -8.h; + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float md = -8.h * xb->d; const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = il ? 0xF000 : 0x0F00; + const ushort mask1 = mask0 << 8; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d; - reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; + reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md; + reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md; } } template void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { device const uint16_t * qs = ((device const uint16_t *)xb + 2); - const half d = il ? (xb->d / 16.h) : xb->d; - const half m = xb->m; + const float d1 = il ? (xb->d / 16.h) : xb->d; + const float d2 = d1 / 256.f; + const float m = xb->m; const ushort mask0 = il ? 0x00F0 : 0x000F; - const ushort mask1 = il ? 0xF000 : 0x0F00; + const ushort mask1 = mask0 << 8; for (int i=0;i<8;i++) { - reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m; - reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m; + reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m; + reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m; } } @@ -1858,7 +2181,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg template void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { - const float d_all = (float)(xb->d); + const half d_all = xb->d; device const uint8_t * q = (device const uint8_t *)xb->qs; device const uint8_t * h = (device const uint8_t *)xb->hmask; device const int8_t * scales = (device const int8_t *)xb->scales; @@ -1871,16 +2194,18 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg ((il/4)>0 ? 12 : 3); uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; - int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \ - (scale_2&kmask2) | ((scale_1&kmask1) << 4); - float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); + int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) + : (scale_2&kmask2) | ((scale_1&kmask1) << 4); + half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h); + const half ml = 4.h * dl; - il = (il/2)%4; - float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); - uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + il = (il/2) & 3; + const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); + const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + dl *= coef; for (int i = 0; i < 16; ++i) { - reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef)); + reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml); } #else float kcoef = il&1 ? 1.f/16.f : 1.f; @@ -1895,26 +2220,31 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg #endif } +static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) { + return j < 4 ? 
uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)} + : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))}; +} + template void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { - device const uint8_t * q = xb->qs; + device const uchar * q = xb->qs; #if QK_K == 256 - const float d = (float)(xb->d); - const float min = (float)(xb->dmin); short is = (il/4) * 2; q = q + (il/4) * 32 + 16 * (il&1); - il = il%4; - const uchar4 sc = get_scale_min_k4(is, xb->scales); - const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; - const float ml = il<2 ? min * sc[1] : min * sc[3]; + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const half d = il < 2 ? xb->d : xb->d / 16.h; + const half min = xb->dmin; + const half dl = d * sc[0]; + const half ml = min * sc[1]; #else q = q + 16 * (il&1); device const uint8_t * s = xb->scales; device const half2 * dh = (device const half2 *)xb->d; const float2 d = (float2)dh[0]; const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; - const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4); + const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4); #endif const ushort mask = il<2 ? 0x0F : 0xF0; for (int i = 0; i < 16; ++i) { @@ -1928,19 +2258,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg device const uint8_t * qh = xb->qh; #if QK_K == 256 - const float d = (float)(xb->d); - const float min = (float)(xb->dmin); short is = (il/4) * 2; q = q + 32 * (il/4) + 16 * (il&1); qh = qh + 16 * (il&1); uint8_t ul = 1 << (il/2); - il = il%4; - const uchar4 sc = get_scale_min_k4(is, xb->scales); - const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; - const float ml = il<2 ? min * sc[1] : min * sc[3]; + il = il & 3; + const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales); + const half d = il < 2 ? xb->d : xb->d / 16.h; + const half min = xb->dmin; + const half dl = d * sc[0]; + const half ml = min * sc[1]; - const ushort mask = il<2 ? 0x0F : 0xF0; - const float qh_val = il<2 ? 16.f : 256.f; + const ushort mask = il<2 ? 0x0F : 0xF0; + const half qh_val = il<2 ? 16.h : 256.h; for (int i = 0; i < 16; ++i) { reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; } @@ -1959,7 +2289,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg template void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { - const float d_all = (float)(xb->d); + const half d_all = xb->d; device const uint8_t * ql = (device const uint8_t *)xb->ql; device const uint8_t * qh = (device const uint8_t *)xb->qh; device const int8_t * scales = (device const int8_t *)xb->scales; @@ -1967,19 +2297,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg #if QK_K == 256 ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); qh = qh + 32*(il/8) + 16*(il&1); - float sc = scales[(il%2) + 2 * ((il/2))]; - il = (il/2)%4; + half sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; #else ql = ql + 16 * (il&1); - float sc = scales[il]; + half sc = scales[il]; #endif + const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); + const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; + const half coef = il>1 ? 1.f/16.h : 1.h; + const half ml = d_all * sc * 32.h; + const half dl = d_all * sc * coef; for (int i = 0; i < 16; ++i) { - uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); - uint16_t kmask2 = il>1 ? 
0xF0 : 0x0F; - const float coef = il>1 ? 1.f/16.f : 1.f; - float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \ - ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef; - reg[i/4][i%4] = d_all * sc * q * coef; + const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2)) + : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4)); + reg[i/4][i%4] = dl * q - ml; } } @@ -2006,7 +2338,7 @@ kernel void kernel_get_rows( } #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A -#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A +#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B #define BLOCK_SIZE_K 32 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B @@ -2019,35 +2351,40 @@ kernel void kernel_get_rows( // each block_q contains 16*nl weights template kernel void kernel_mul_mm(device const uchar * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & ne12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & gqa, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - threadgroup half * sa = ((threadgroup half *)shared_memory); + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); const uint r0 = tgpig.y; const uint r1 = tgpig.x; const uint im = tgpig.z; + // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + // a thread shouldn't load data outside of the matrix short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? 
((short)tiitg/THREAD_PER_COL) : n_cols - 1; - simdgroup_half8x8 ma[4]; + simdgroup_half8x8 ma[4]; simdgroup_float8x8 mb[2]; simdgroup_float8x8 c_res[8]; for (int i = 0; i < 8; i++){ @@ -2055,32 +2392,41 @@ kernel void kernel_mul_mm(device const uchar * src0, } short il = (tiitg % THREAD_PER_ROW); - uint offset0 = im/gqa*nb02; ushort offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; - device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ - + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1; + + uint offset0 = im/gqa*nb02; + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * im + + nb11 * (r1 * BLOCK_SIZE_N + thread_col) + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { - //load data and store to threadgroup memory + // load data and store to threadgroup memory half4x4 temp_a; dequantize_func(x, il, temp_a); threadgroup_barrier(mem_flags::mem_threadgroup); + #pragma unroll(16) for (int i = 0; i < 16; i++) { *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ - + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \ - + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; } - *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \ - = *((device float2x4 *)y); + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + il = (il + 2 < nl) ? il + 2 : il % 2; x = (il < 2) ? 
x + (2+nl-1)/nl : x; y += BLOCK_SIZE_K; threadgroup_barrier(mem_flags::mem_threadgroup); - //load matrices from threadgroup memory and conduct outer products + + // load matrices from threadgroup memory and conduct outer products threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(4) for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { #pragma unroll(4) @@ -2095,6 +2441,7 @@ kernel void kernel_mul_mm(device const uchar * src0, lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) for (int i = 0; i < 8; i++){ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); @@ -2103,25 +2450,26 @@ kernel void kernel_mul_mm(device const uchar * src0, } if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0; + device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ + + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; for (int i = 0; i < 8; i++) { simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); } } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float *temp_str = ((threadgroup float *)shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; for (int i = 0; i < 8; i++) { simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); - device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; - if (sgitg==0) { + + device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0; + if (sgitg == 0) { for (int i = 0; i < n_rows; i++) { - for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); } } @@ -2138,6 +2486,7 @@ kernel void kernel_mul_mm(device const uchar * src0, typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ constant uint64_t &, constant uint64_t &, uint, uint, uint); +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; @@ -2148,14 +2497,28 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows; -typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ - constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ - constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint); - -template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; -template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t 
kernel_mul_mm; -template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; +typedef void (mat_mm_t)( + device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar *, uint3, uint, uint); + +template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/LLama/runtimes/libllama-cuda11.dll b/LLama/runtimes/libllama-cuda11.dll index 6ed31810..e5fc7dad 100644 Binary files a/LLama/runtimes/libllama-cuda11.dll and b/LLama/runtimes/libllama-cuda11.dll differ diff --git a/LLama/runtimes/libllama-cuda11.so b/LLama/runtimes/libllama-cuda11.so index 81733cdd..3532fe99 100644 Binary files a/LLama/runtimes/libllama-cuda11.so and b/LLama/runtimes/libllama-cuda11.so differ diff --git a/LLama/runtimes/libllama-cuda12.dll b/LLama/runtimes/libllama-cuda12.dll index f1a9fbdc..89f27e24 100644 Binary files a/LLama/runtimes/libllama-cuda12.dll and b/LLama/runtimes/libllama-cuda12.dll differ diff --git a/LLama/runtimes/libllama-cuda12.so b/LLama/runtimes/libllama-cuda12.so index 482fe2f2..81b4aa99 100644 Binary files a/LLama/runtimes/libllama-cuda12.so and b/LLama/runtimes/libllama-cuda12.so differ diff --git a/LLama/runtimes/libllama.dll b/LLama/runtimes/libllama.dll index a5f774f8..62d071ec 100644 Binary files a/LLama/runtimes/libllama.dll and b/LLama/runtimes/libllama.dll differ diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib index 5bb4497d..c2ca7ec8 100755 Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ diff --git a/LLama/runtimes/libllama.so b/LLama/runtimes/libllama.so index e52d6bda..b9ef4c1d 100644 Binary files a/LLama/runtimes/libllama.so and b/LLama/runtimes/libllama.so differ
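Taken together, the renamed native entry points reshape the typical evaluation loop on the managed side: thread counts are set once on the context, tokenization takes an explicit `special` flag, and Eval no longer accepts a thread count. A rough sketch of the new call pattern (a fragment, not part of this change set; `ctx` and `prompt` are assumed to exist, and `using System.Text;` / `using LLama.Native;` are implied):

// Sketch of the post-refactor call pattern. Error handling kept minimal.
NativeApi.llama_set_n_threads(ctx, n_threads: 8, n_threads_batch: 8);

// `special: false` keeps special/control tokens treated as plain text.
int[] tokens = ctx.Tokenize(prompt, add_bos: true, special: false, Encoding.UTF8);

// Eval now takes thread settings from the context instead of a parameter.
if (!ctx.Eval(tokens, n_past: 0))
    throw new RuntimeError("llama_eval reported failure");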