diff --git a/LLama.Unittest/LLamaContextTests.cs b/LLama.Unittest/LLamaContextTests.cs
index 345e518d..ab27d988 100644
--- a/LLama.Unittest/LLamaContextTests.cs
+++ b/LLama.Unittest/LLamaContextTests.cs
@@ -28,7 +28,7 @@ namespace LLama.Unittest
[Fact]
public void CheckProperties()
{
- Assert.Equal(768, _context.ContextSize);
+ Assert.Equal(768u, _context.ContextSize);
Assert.Equal(4096, _context.EmbeddingSize);
Assert.Equal(32000, _context.VocabCount);
}
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index 7b770b38..e462401a 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -23,6 +23,9 @@ namespace LLama.Web.Common
///
public int MainGpu { get; set; } = 0;
+ ///
+ public GPUSplitMode SplitMode { get; set; } = GPUSplitMode.None;
+
///
public int GpuLayerCount { get; set; } = 20;
diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs
index 3ef41bec..73e03b99 100644
--- a/LLama/Abstractions/IModelParams.cs
+++ b/LLama/Abstractions/IModelParams.cs
@@ -16,9 +16,28 @@ namespace LLama.Abstractions
public interface IModelParams
{
///
- /// the GPU that is used for scratch and small tensors
+ /// main_gpu interpretation depends on split_mode:
+ ///
+ /// - None: The GPU that is used for the entire model.
+ /// - Row: The GPU that is used for small tensors and intermediate results.
+ /// - Layer: Ignored.
+ ///
///
- int MainGpu { get; }
+ int MainGpu { get; set; }
+
+ ///
+ /// How to split the model across multiple GPUs
+ ///
+ GPUSplitMode SplitMode { get; }
///
/// Number of layers to run in VRAM / GPU memory (n_gpu_layers)
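A minimal sketch of how the new SplitMode interacts with MainGpu from the caller's side; the model path and values are placeholders, not part of this change:

    var parameters = new ModelParams("model.gguf")      // placeholder path
    {
        MainGpu = 0,                    // used for the whole model when SplitMode is None
        SplitMode = GPUSplitMode.Layer, // MainGpu is ignored in Layer mode
        GpuLayerCount = 20,
    };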
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index b124b84d..3afee9cb 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -18,6 +18,9 @@ namespace LLama.Common
///
public int MainGpu { get; set; } = 0;
+ ///
+ public GPUSplitMode SplitMode { get; set; } = GPUSplitMode.None;
+
///
public int GpuLayerCount { get; set; } = 20;
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index 21273617..cd3075ab 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -36,6 +36,9 @@ namespace LLama.Extensions
result.yarn_orig_ctx = @params.YarnOriginalContext ?? 0;
result.rope_scaling_type = @params.YarnScalingType ?? RopeScalingType.LLAMA_ROPE_SCALING_UNSPECIFIED;
+ result.cb_eval = IntPtr.Zero;
+ result.cb_eval_user_data = IntPtr.Zero;
+
result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = !@params.NoKqvOffload;
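A short sketch of what the TypeK/TypeV mappings above enable from user code: requesting a quantized KV cache. This assumes ModelParams exposes TypeK/TypeV from IContextParams; the values are illustrative and backend-dependent:

    var ctxParams = new ModelParams("model.gguf")   // placeholder path
    {
        TypeK = GGMLType.GGML_TYPE_Q8_0,   // quantize the K cache
        TypeV = GGMLType.GGML_TYPE_F16,    // keep V in fp16
    };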
diff --git a/LLama/Extensions/IModelParamsExtensions.cs b/LLama/Extensions/IModelParamsExtensions.cs
index f7fadece..69b9e288 100644
--- a/LLama/Extensions/IModelParamsExtensions.cs
+++ b/LLama/Extensions/IModelParamsExtensions.cs
@@ -21,15 +21,16 @@ public static class IModelParamsExtensions
///
public static IDisposable ToLlamaModelParams(this IModelParams @params, out LLamaModelParams result)
{
- if (@params.UseMemoryLock && !NativeApi.llama_mlock_supported())
- throw new NotSupportedException("'UseMemoryLock' is not supported (llama_mlock_supported() == false)");
- if (@params.UseMemorymap && !NativeApi.llama_mmap_supported())
- throw new NotSupportedException("'UseMemorymap' is not supported (llama_mmap_supported() == false)");
+ if (@params.UseMemoryLock && !NativeApi.llama_supports_mlock())
+ throw new NotSupportedException("'UseMemoryLock' is not supported (llama_supports_mlock() == false)");
+ if (@params.UseMemorymap && !NativeApi.llama_supports_mmap())
+ throw new NotSupportedException("'UseMemorymap' is not supported (llama_supports_mmap() == false)");
var disposer = new GroupDisposable();
result = NativeApi.llama_model_default_params();
result.main_gpu = @params.MainGpu;
+ result.split_mode = @params.SplitMode;
result.n_gpu_layers = @params.GpuLayerCount < 0 ? int.MaxValue : @params.GpuLayerCount;
result.use_mlock = @params.UseMemoryLock;
result.use_mmap = @params.UseMemorymap;
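Since the renamed capability probes are public, callers can also check support up front instead of catching NotSupportedException; a hedged sketch with a placeholder path:

    var p = new ModelParams("model.gguf")
    {
        UseMemoryLock = NativeApi.llama_supports_mlock(),
        UseMemorymap = NativeApi.llama_supports_mmap(),
    };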
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 6d39a8f9..5d026b67 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -33,7 +33,7 @@ namespace LLama
///
/// Total number of tokens in the context
///
- public int ContextSize => NativeHandle.ContextSize;
+ public uint ContextSize => NativeHandle.ContextSize;
///
/// Dimension of embedding vectors
@@ -323,7 +323,7 @@ namespace LLama
var candidates_p = LLamaTokenDataArray.Create(logits);
// Extract most recently returned tokens
- var last_n_repeat = Math.Min(ContextSize, repeatLastTokensCount);
+ var last_n_repeat = Math.Min((int)ContextSize, repeatLastTokensCount);
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();
// Apply penalties to candidates
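With ContextSize now unsigned, int-typed tuning parameters need an explicit cast before comparison; a checked cast keeps an oversized n_ctx from silently wrapping. Sketch with illustrative names:

    int repeatLastTokensCount = 64;                          // illustrative value
    var lastN = Math.Min(checked((int)context.ContextSize),  // throws on overflow instead of wrapping
                         repeatLastTokensCount);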
diff --git a/LLama/LLamaExecutorBase.cs b/LLama/LLamaExecutorBase.cs
index 4713166e..3a697507 100644
--- a/LLama/LLamaExecutorBase.cs
+++ b/LLama/LLamaExecutorBase.cs
@@ -83,7 +83,7 @@ namespace LLama
_pastTokensCount = 0;
_consumedTokensCount = 0;
_n_session_consumed = 0;
- _last_n_tokens = new FixedSizeQueue<LLamaToken>(Context.ContextSize);
+ _last_n_tokens = new FixedSizeQueue<LLamaToken>((int)Context.ContextSize);
_decoder = new StreamingTokenDecoder(context);
}
@@ -170,7 +170,7 @@ namespace LLama
_pastTokensCount = Math.Max(1, tokensToKeep);
// insert n_left/2 tokens at the start of embed from last_n_tokens
- _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip(Context.ContextSize - n_left / 2 - _embeds.Count));
+ _embeds.InsertRange(0, _last_n_tokens.Take(_last_n_tokens.Count - _embeds.Count).Skip((int)Context.ContextSize - n_left / 2 - _embeds.Count));
// stop saving session if we run out of context
_pathSession = string.Empty;
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 993019f1..969da783 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -200,7 +200,7 @@ namespace LLama
if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
{
- var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount;
+ var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? (int)Context.ContextSize : inferenceParams.RepeatLastTokensCount;
// optionally save the session on first sample (for faster prompt loading next time)
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 2e72c7ae..7d742c81 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -179,7 +179,7 @@ namespace LLama
if (_embed_inps.Count <= _consumedTokensCount && !args.WaitForInput)
{
- var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount;
+ var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? (int)Context.ContextSize : inferenceParams.RepeatLastTokensCount;
// optionally save the session on first sample (for faster prompt loading next time)
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
diff --git a/LLama/Native/GPUSplitMode.cs b/LLama/Native/GPUSplitMode.cs
new file mode 100644
index 00000000..96957d0f
--- /dev/null
+++ b/LLama/Native/GPUSplitMode.cs
@@ -0,0 +1,23 @@
+namespace LLama.Native;
+
+///
+/// How to split the model across multiple GPUs
+///
+/// llama_split_mode
+public enum GPUSplitMode
+{
+ ///
+ /// Single GPU
+ ///
+ None = 0,
+
+ ///
+ /// Split layers and KV across GPUs
+ ///
+ Layer = 1,
+
+ ///
+ /// Split rows across GPUs
+ ///
+ Row = 2,
+}
\ No newline at end of file
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index bfd39ea4..118dd540 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -8,7 +8,8 @@ namespace LLama.Native
///
///
///
- public delegate void LlamaProgressCallback(float progress, IntPtr ctx);
+ /// llama_progress_callback
+ public delegate bool LlamaProgressCallback(float progress, IntPtr ctx);
///
/// A C# representation of the llama.cpp `llama_context_params` struct
@@ -46,37 +47,46 @@ namespace LLama.Native
///
public RopeScalingType rope_scaling_type;
-
///
/// RoPE base frequency, 0 = from model
///
- public float rope_freq_base;
+ public float rope_freq_base;
///
/// RoPE frequency scaling factor, 0 = from model
///
- public float rope_freq_scale;
+ public float rope_freq_scale;
///
/// YaRN extrapolation mix factor, negative = from model
///
- public float yarn_ext_factor;
+ public float yarn_ext_factor;
///
/// YaRN magnitude scaling factor
///
- public float yarn_attn_factor;
+ public float yarn_attn_factor;
///
/// YaRN low correction dim
///
- public float yarn_beta_fast;
+ public float yarn_beta_fast;
///
/// YaRN high correction dim
///
- public float yarn_beta_slow;
+ public float yarn_beta_slow;
///
/// YaRN original context size
///
public uint yarn_orig_ctx;
+ ///
+ /// ggml_backend_sched_eval_callback
+ ///
+ public IntPtr cb_eval;
+
+ ///
+ /// User data passed into cb_eval
+ ///
+ public IntPtr cb_eval_user_data;
+
///
/// data type for K cache
///
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index 0fa0fbe9..8eb0a8b9 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -106,6 +106,31 @@
/// Benchmark@7B: 5.15GB, +0.0044 ppl
LLAMA_FTYPE_MOSTLY_Q6_K = 18,
+ ///
+ /// except 1d tensors
+ ///
+ LLAMA_FTYPE_MOSTLY_IQ2_XXS = 19,
+
+ ///
+ /// except 1d tensors
+ ///
+ LLAMA_FTYPE_MOSTLY_IQ2_XS = 20,
+
+ ///
+ /// except 1d tensors
+ ///
+ LLAMA_FTYPE_MOSTLY_Q2_K_S = 21,
+
+ ///
+ /// except 1d tensors
+ ///
+ LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22,
+
+ ///
+ /// except 1d tensors
+ ///
+ LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23,
+
///
/// File type was not specified
///
diff --git a/LLama/Native/LLamaModelParams.cs b/LLama/Native/LLamaModelParams.cs
index ed7b6043..a7cdd1a2 100644
--- a/LLama/Native/LLamaModelParams.cs
+++ b/LLama/Native/LLamaModelParams.cs
@@ -14,6 +14,11 @@ namespace LLama.Native
///
public int n_gpu_layers;
+ ///
+ /// how to split the model across multiple GPUs
+ ///
+ public GPUSplitMode split_mode;
+
///
/// the GPU that is used for scratch and small tensors
///
@@ -25,7 +30,8 @@ namespace LLama.Native
public float* tensor_split;
///
- /// called with a progress value between 0 and 1, pass NULL to disable
+ /// called with a progress value between 0 and 1, pass NULL to disable. If the provided progress_callback
+ /// returns true, model loading continues. If it returns false, model loading is immediately aborted.
///
public LlamaProgressCallback progress_callback;
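A minimal sketch of the new callback contract (the handler name is illustrative): returning false aborts the load, so long-running loads become cancellable.

    static bool OnLoadProgress(float progress, IntPtr userData)
    {
        Console.WriteLine($"loading: {progress:P0}");
        return true;   // return false here to abort model loading
    }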
diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs
index 39702b5a..34c1a974 100644
--- a/LLama/Native/LLamaModelQuantizeParams.cs
+++ b/LLama/Native/LLamaModelQuantizeParams.cs
@@ -6,6 +6,7 @@ namespace LLama.Native
///
/// Quantizer parameters used in the native API
///
+ /// llama_model_quantize_params
[StructLayout(LayoutKind.Sequential)]
public struct LLamaModelQuantizeParams
{
@@ -58,5 +59,10 @@ namespace LLama.Native
set => _pure = Convert.ToSByte(value);
}
private sbyte _pure;
+
+ ///
+ /// pointer to importance matrix data
+ ///
+ public IntPtr imatrix;
}
}
diff --git a/LLama/Native/NativeApi.Quantize.cs b/LLama/Native/NativeApi.Quantize.cs
index b849e38d..1c4909bf 100644
--- a/LLama/Native/NativeApi.Quantize.cs
+++ b/LLama/Native/NativeApi.Quantize.cs
@@ -10,9 +10,8 @@ namespace LLama.Native
///
///
///
- /// not great API - very likely to change
/// Returns 0 on success
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern unsafe int llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* param);
+ public static extern unsafe uint llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* param);
}
}
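A hedged usage sketch of the updated signature, assuming a default-params binding mirroring llama.cpp's llama_model_quantize_default_params(); file names are placeholders.

    unsafe
    {
        var qp = NativeApi.llama_model_quantize_default_params();   // assumed binding
        qp.ftype = LLamaFtype.LLAMA_FTYPE_MOSTLY_Q6_K;              // illustrative target
        var rc = NativeApi.llama_model_quantize("in.gguf", "out.gguf", &qp);
        if (rc != 0)                                                // 0 means success
            throw new InvalidOperationException($"quantize failed ({rc})");
    }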
diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs
index 7128441e..a52edc66 100644
--- a/LLama/Native/NativeApi.Sampling.cs
+++ b/LLama/Native/NativeApi.Sampling.cs
@@ -27,11 +27,12 @@ namespace LLama.Native
/// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
///
///
- /// A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
- /// A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
+ /// Logits extracted from the original generation context.
+ /// Logits extracted from a separate context from the same model.
+ /// Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
/// Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_sample_classifier_free_guidance(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, SafeLLamaContextHandle guidance_ctx, float scale);
+ public static extern unsafe void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, float* logits, float* logits_guidance, float scale);
///
/// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
@@ -92,6 +93,17 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_sample_typical(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float p, ulong min_keep);
+ ///
+ /// Dynamic temperature implementation described in the paper https://arxiv.org/abs/2309.02772.
+ ///
+ ///
+ /// Pointer to LLamaTokenDataArray
+ ///
+ ///
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern void llama_sample_entropy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, float min_temp, float max_temp, float exponent_val);
+
///
/// Modify logits by temperature
///
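A hedged sketch of driving the new dynamic-temperature sampler from managed code, assuming LLamaTokenDataArrayNative.Create pins the managed array; the temperature bounds are illustrative and logits/ctx come from the surrounding generation loop:

    var candidates = LLamaTokenDataArray.Create(logits);
    using (LLamaTokenDataArrayNative.Create(candidates, out var native))
    {
        NativeApi.llama_sample_entropy(ctx, ref native, 0.1f, 2.0f, 1.0f);
    }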
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index bb28e7ab..c953cb23 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -23,7 +23,7 @@ namespace LLama.Native
///
public static void llama_empty_call()
{
- llama_mmap_supported();
+ llama_max_devices();
}
///
@@ -31,7 +31,7 @@ namespace LLama.Native
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_max_devices();
+ public static extern long llama_max_devices();
///
/// Create a LLamaModelParams with default values
@@ -59,14 +59,21 @@ namespace LLama.Native
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern bool llama_mmap_supported();
+ public static extern bool llama_supports_mmap();
///
- /// Check if memory lockingis supported
+ /// Check if memory locking is supported
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern bool llama_mlock_supported();
+ public static extern bool llama_supports_mlock();
+
+ ///
+ /// Check if GPU offload is supported
+ ///
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern bool llama_supports_gpu_offload();
///
/// Initialize the llama + ggml backend
@@ -163,7 +170,10 @@ namespace LLama.Native
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_n_ctx(SafeLLamaContextHandle ctx);
+ public static extern uint llama_n_ctx(SafeLLamaContextHandle ctx);
+
+ ///
+ /// Get the batch size (n_batch) of this context
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern uint llama_n_batch(SafeLLamaContextHandle ctx);
///
/// Token logits obtained from the last call to llama_eval()
@@ -380,6 +390,20 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_kv_cache_seq_shift(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, LLamaPos delta);
+ ///
+ /// Integer division of the positions by factor of `d > 1`
+ /// If the KV cache is RoPEd, the KV data is updated accordingly
+ /// p0 < 0 : [0, p1]
+ /// p1 < 0 : [p0, inf)
+ ///
+ ///
+ ///
+ ///
+ ///
+ ///
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern void llama_kv_cache_seq_div(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1, int d);
+
///
/// Allocates a batch of tokens on the heap
/// Each token can be assigned up to n_seq_max sequence ids
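Illustrative use of the new position division (as in "self-extend" style context compression): halve the RoPE positions of sequence 0 over [0, n_past). The variables are placeholders, and the int-to-LLamaPos/LLamaSeqId conversions are assumed from the existing operators.

    NativeApi.llama_kv_cache_seq_div(ctx, (LLamaSeqId)0, (LLamaPos)0, (LLamaPos)n_past, 2);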
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 2c5d8288..d90d46d5 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -21,7 +21,7 @@ namespace LLama.Native
///
/// Total number of tokens in the context
///
- public int ContextSize => NativeApi.llama_n_ctx(this);
+ public uint ContextSize => NativeApi.llama_n_ctx(this);
///
/// Dimension of embedding vectors
diff --git a/LLama/runtimes/deps/avx/libllama.so b/LLama/runtimes/deps/avx/libllama.so
index 4b788e62..49a80191 100644
Binary files a/LLama/runtimes/deps/avx/libllama.so and b/LLama/runtimes/deps/avx/libllama.so differ
diff --git a/LLama/runtimes/deps/avx/llama.dll b/LLama/runtimes/deps/avx/llama.dll
index 954bb194..e9924272 100644
Binary files a/LLama/runtimes/deps/avx/llama.dll and b/LLama/runtimes/deps/avx/llama.dll differ
diff --git a/LLama/runtimes/deps/avx2/libllama.so b/LLama/runtimes/deps/avx2/libllama.so
index c299ee65..ffa59a49 100644
Binary files a/LLama/runtimes/deps/avx2/libllama.so and b/LLama/runtimes/deps/avx2/libllama.so differ
diff --git a/LLama/runtimes/deps/avx2/llama.dll b/LLama/runtimes/deps/avx2/llama.dll
index 8a0e86c7..996ee0a1 100644
Binary files a/LLama/runtimes/deps/avx2/llama.dll and b/LLama/runtimes/deps/avx2/llama.dll differ
diff --git a/LLama/runtimes/deps/avx512/libllama.so b/LLama/runtimes/deps/avx512/libllama.so
index e9290e66..baee01e6 100644
Binary files a/LLama/runtimes/deps/avx512/libllama.so and b/LLama/runtimes/deps/avx512/libllama.so differ
diff --git a/LLama/runtimes/deps/avx512/llama.dll b/LLama/runtimes/deps/avx512/llama.dll
index 709faf9a..754d5cfb 100644
Binary files a/LLama/runtimes/deps/avx512/llama.dll and b/LLama/runtimes/deps/avx512/llama.dll differ
diff --git a/LLama/runtimes/deps/clblast/clblast.dll b/LLama/runtimes/deps/clblast/clblast.dll
new file mode 100644
index 00000000..4f2a065c
Binary files /dev/null and b/LLama/runtimes/deps/clblast/clblast.dll differ
diff --git a/LLama/runtimes/deps/clblast/libllama.so b/LLama/runtimes/deps/clblast/libllama.so
new file mode 100644
index 00000000..bf2ca9e1
Binary files /dev/null and b/LLama/runtimes/deps/clblast/libllama.so differ
diff --git a/LLama/runtimes/deps/clblast/llama.dll b/LLama/runtimes/deps/clblast/llama.dll
new file mode 100644
index 00000000..55d4a851
Binary files /dev/null and b/LLama/runtimes/deps/clblast/llama.dll differ
diff --git a/LLama/runtimes/deps/cu11.7.1/libllama.so b/LLama/runtimes/deps/cu11.7.1/libllama.so
index 9bce0d51..e6eee3cd 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/libllama.so and b/LLama/runtimes/deps/cu11.7.1/libllama.so differ
diff --git a/LLama/runtimes/deps/cu11.7.1/llama.dll b/LLama/runtimes/deps/cu11.7.1/llama.dll
index 4440d33e..e75a353e 100644
Binary files a/LLama/runtimes/deps/cu11.7.1/llama.dll and b/LLama/runtimes/deps/cu11.7.1/llama.dll differ
diff --git a/LLama/runtimes/deps/cu12.1.0/libllama.so b/LLama/runtimes/deps/cu12.1.0/libllama.so
index 8b579ed2..dbc7d066 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/libllama.so and b/LLama/runtimes/deps/cu12.1.0/libllama.so differ
diff --git a/LLama/runtimes/deps/cu12.1.0/llama.dll b/LLama/runtimes/deps/cu12.1.0/llama.dll
index cab4b10b..88ea37f8 100644
Binary files a/LLama/runtimes/deps/cu12.1.0/llama.dll and b/LLama/runtimes/deps/cu12.1.0/llama.dll differ
diff --git a/LLama/runtimes/deps/libllama.so b/LLama/runtimes/deps/libllama.so
index 670555d1..b9f6a819 100644
Binary files a/LLama/runtimes/deps/libllama.so and b/LLama/runtimes/deps/libllama.so differ
diff --git a/LLama/runtimes/deps/llama.dll b/LLama/runtimes/deps/llama.dll
index 2aa3afdc..9325dadf 100644
Binary files a/LLama/runtimes/deps/llama.dll and b/LLama/runtimes/deps/llama.dll differ
diff --git a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
index 773fac12..efed6ad4 100644
--- a/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
+++ b/LLama/runtimes/deps/osx-arm64/ggml-metal.metal
@@ -59,26 +59,27 @@ kernel void kernel_add(
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
- constant int64_t & nb00,
- constant int64_t & nb01,
- constant int64_t & nb02,
- constant int64_t & nb03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
- constant int64_t & nb10,
- constant int64_t & nb11,
- constant int64_t & nb12,
- constant int64_t & nb13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
- constant int64_t & nb0,
- constant int64_t & nb1,
- constant int64_t & nb2,
- constant int64_t & nb3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int64_t & offs,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@@ -90,9 +91,9 @@ kernel void kernel_add(
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
const int i10 = i0 % ne10;
@@ -108,26 +109,26 @@ kernel void kernel_mul(
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
- constant int64_t & nb00,
- constant int64_t & nb01,
- constant int64_t & nb02,
- constant int64_t & nb03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
- constant int64_t & nb10,
- constant int64_t & nb11,
- constant int64_t & nb12,
- constant int64_t & nb13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
- constant int64_t & nb0,
- constant int64_t & nb1,
- constant int64_t & nb2,
- constant int64_t & nb3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@@ -157,26 +158,26 @@ kernel void kernel_div(
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
- constant int64_t & nb00,
- constant int64_t & nb01,
- constant int64_t & nb02,
- constant int64_t & nb03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
- constant int64_t & nb10,
- constant int64_t & nb11,
- constant int64_t & nb12,
- constant int64_t & nb13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
- constant int64_t & nb0,
- constant int64_t & nb1,
- constant int64_t & nb2,
- constant int64_t & nb3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
@@ -204,7 +205,7 @@ kernel void kernel_add_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant int64_t & nb [[buffer(27)]],
+ constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
@@ -213,7 +214,7 @@ kernel void kernel_mul_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant int64_t & nb [[buffer(27)]],
+ constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig % nb];
}
@@ -222,7 +223,7 @@ kernel void kernel_div_row(
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant int64_t & nb [[buffer(27)]],
+ constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] / src1[tpig % nb];
}
@@ -243,19 +244,53 @@ kernel void kernel_scale_4(
dst[tpig] = src0[tpig] * scale;
}
-kernel void kernel_silu(
- device const float4 * src0,
- device float4 * dst,
+kernel void kernel_relu(
+ device const float * src0,
+ device float * dst,
uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
- dst[tpig] = x / (1.0f + exp(-x));
+ dst[tpig] = max(0.0f, src0[tpig]);
}
-kernel void kernel_relu(
+kernel void kernel_tanh(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = max(0.0f, src0[tpig]);
+ device const float & x = src0[tpig];
+ dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ // BEWARE !!!
+ // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+ // This was observed with Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_sqr(
@@ -272,26 +307,26 @@ kernel void kernel_sum_rows(
constant int64_t & ne01,
constant int64_t & ne02,
constant int64_t & ne03,
- constant int64_t & nb00,
- constant int64_t & nb01,
- constant int64_t & nb02,
- constant int64_t & nb03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
constant int64_t & ne10,
constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
- constant int64_t & nb10,
- constant int64_t & nb11,
- constant int64_t & nb12,
- constant int64_t & nb13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
constant int64_t & ne0,
constant int64_t & ne1,
constant int64_t & ne2,
constant int64_t & ne3,
- constant int64_t & nb0,
- constant int64_t & nb1,
- constant int64_t & nb2,
- constant int64_t & nb3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
uint3 tpig[[thread_position_in_grid]]) {
int64_t i3 = tpig.z;
int64_t i2 = tpig.y;
@@ -313,22 +348,6 @@ kernel void kernel_sum_rows(
dst_row[0] = row_sum;
}
-constant float GELU_COEF_A = 0.044715f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-kernel void kernel_gelu(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
-
- // BEWARE !!!
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
- // This was observed with Falcon 7B and 40B models
- //
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
kernel void kernel_soft_max(
device const float * src0,
device const float * src1,
@@ -650,6 +669,94 @@ kernel void kernel_rms_norm(
}
}
+kernel void kernel_group_norm(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int32_t & n_groups,
+ constant float & eps,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t ne = ne00*ne01*ne02;
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+ int start = tgpig * gs;
+ int end = start + gs;
+
+ start += tpitg;
+
+ if (end >= ne) {
+ end = ne;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += ntg) {
+ tmp += src0[j];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float mean = tmp / gs;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += ntg) {
+ float xi = src0[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float variance = tmp / gs;
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int j = start; j < end; j += ntg) {
+ dst[j] *= scale;
+ }
+}
+
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
@@ -739,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
//Note: This is a template, but strictly speaking it only applies to
// quantizations where the block size is 32. It also does not
-// giard against the number of rows not being divisible by
+// guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32_impl(
@@ -813,14 +920,21 @@ kernel void kernel_mul_mv_q4_0_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -832,14 +946,21 @@ kernel void kernel_mul_mv_q4_1_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -851,14 +972,21 @@ kernel void kernel_mul_mv_q5_0_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -870,14 +998,21 @@ kernel void kernel_mul_mv_q5_1_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -964,12 +1099,19 @@ kernel void kernel_mul_mv_q8_0_f32(
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
constant int64_t & ne10,
+ constant int64_t & ne11,
constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -1075,8 +1217,8 @@ kernel void kernel_mul_mv_f32_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1102,8 +1244,8 @@ kernel void kernel_mul_mv_f16_f16(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
@@ -1239,8 +1381,8 @@ kernel void kernel_mul_mv_f16_f32_1row(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1345,8 +1487,8 @@ kernel void kernel_mul_mv_f16_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
@@ -1371,8 +1513,8 @@ kernel void kernel_mul_mv_f16_f32_l4(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
@@ -1436,7 +1578,8 @@ kernel void kernel_alibi_f32(
const int64_t i3 = n / (ne2*ne1*ne0);
const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+ //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
const int64_t k = i3*ne3 + i2;
float m_k;
@@ -1595,8 +1738,9 @@ kernel void kernel_rope(
dst_data[1] = x0*sin_theta + x1*cos_theta;
}
} else {
- for (int64_t ib = 0; ib < ne0/n_dims; ++ib) {
- for (int64_t ic = 2*tiitg; ic < n_dims; ic += 2*tptg.x) {
+ for (int64_t ic = 2*tiitg; ic < ne0; ic += 2*tptg.x) {
+ if (ic < n_dims) {
+ const int64_t ib = 0;
// simplified from `(ib * n_dims + ic) * inv_ndims`
const float cur_rot = inv_ndims*ic - ib;
@@ -1615,6 +1759,14 @@ kernel void kernel_rope(
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ } else {
+ const int64_t i0 = ic;
+
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ dst_data[0] = src[0];
+ dst_data[1] = src[1];
}
}
}
@@ -1623,9 +1775,29 @@ kernel void kernel_rope(
template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope;
template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope;
-kernel void kernel_im2col_f16(
+typedef void (im2col_t)(
device const float * x,
- device half * dst,
+ device char * dst,
+ constant int32_t & ofs0,
+ constant int32_t & ofs1,
+ constant int32_t & IW,
+ constant int32_t & IH,
+ constant int32_t & CHW,
+ constant int32_t & s0,
+ constant int32_t & s1,
+ constant int32_t & p0,
+ constant int32_t & p1,
+ constant int32_t & d0,
+ constant int32_t & d1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tgpg[[threadgroups_per_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col(
+ device const float * x,
+ device char * dst,
constant int32_t & ofs0,
constant int32_t & ofs1,
constant int32_t & IW,
@@ -1648,30 +1820,126 @@ kernel void kernel_im2col_f16(
(tpitg[0] * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
(tgpig[0] * (ntg[1] * ntg[2]) + tpitg[1] * ntg[2] + tpitg[2]);
+ device T * pdst = (device T *) (dst);
+
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
- dst[offset_dst] = 0.0f;
+ pdst[offset_dst] = 0.0f;
} else {
const int32_t offset_src = tpitg[0] * ofs0 + tgpig[0] * ofs1;
- dst[offset_dst] = x[offset_src + iih * IW + iiw];
+ pdst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
-// bitonic sort implementation following the CUDA kernels as reference
-typedef void (argsort_t)(
- device const float * x,
- device int32_t * dst,
- constant int64_t & ncols,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]]);
+template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col;
+template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col;
-template<ggml_sort_order order>
-kernel void kernel_argsort_f32_i32(
- device const float * x,
- device int32_t * dst,
- constant int64_t & ncols,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]]) {
- // bitonic sort
+kernel void kernel_upscale_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int32_t & sf,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1/sf;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = src0_ptr[i0/sf];
+ }
+}
+
+kernel void kernel_pad_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ if (i0 < ne00) {
+ dst_ptr[i0] = src0_ptr[i0];
+ } else {
+ dst_ptr[i0] = 0.0f;
+ }
+ }
+
+ return;
+ }
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = 0.0f;
+ }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ // bitonic sort
int col = tpitg[0];
int row = tgpig[1];
@@ -1708,6 +1976,14 @@ kernel void kernel_argsort_f32_i32(
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32;
+kernel void kernel_leaky_relu_f32(
+ device const float * src0,
+ device float * dst,
+ constant float & slope,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
kernel void kernel_cpy_f16_f16(
device const half * src0,
device half * dst,
@@ -2066,9 +2342,9 @@ kernel void kernel_cpy_f32_q4_1(
}
kernel void kernel_concat(
- device const char * src0,
- device const char * src1,
- device char * dst,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -2105,7 +2381,7 @@ kernel void kernel_concat(
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
@@ -2195,21 +2471,24 @@ typedef struct {
} block_q6_K;
// 210 bytes / block
-static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
- uchar4 r;
- if (j < 4) {
- r[0] = q[j+0] & 63;
- r[2] = q[j+1] & 63;
- r[1] = q[j+4] & 63;
- r[3] = q[j+5] & 63;
- } else {
- r[0] = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
- r[2] = (q[j+5] & 0xF) | ((q[j-3] >> 6) << 4);
- r[1] = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
- r[3] = (q[j+5] >> 4) | ((q[j+1] >> 6) << 4);
- }
- return r;
-}
+typedef struct {
+ half d;
+ uint16_t qs[QK_K/8];
+} block_iq2_xxs;
+// 66 bytes / block for QK_K = 256, so 2.0625 bpw
+
+typedef struct {
+ half d;
+ uint16_t qs[QK_K/8];
+ uint8_t scales[QK_K/32];
+} block_iq2_xs;
+// 74 bytes / block for QK_K = 256, so 2.3125 bpw
+
+typedef struct {
+ half d;
+ uint8_t qs[3*QK_K/8];
+} block_iq3_xxs;
+// 98 bytes / block for QK_K = 256, so 3.0625 bpw
//====================================== dot products =========================
@@ -2369,14 +2648,21 @@ kernel void kernel_mul_mv_q2_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2626,14 +2912,21 @@ kernel void kernel_mul_mv_q3_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -2769,8 +3062,8 @@ void kernel_mul_mv_q4_K_f32_impl(
constant uint & r2,
constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int ix = tiisg/4; // 0...7
const int it = tiisg%4; // 0...3
@@ -2779,7 +3072,7 @@ void kernel_mul_mv_q4_K_f32_impl(
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = r0 * N_DST;
const int ib_row = first_row * nb;
const uint i12 = im%ne12;
@@ -2845,7 +3138,7 @@ void kernel_mul_mv_q4_K_f32_impl(
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum;
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
}
}
}
@@ -2857,14 +3150,21 @@ kernel void kernel_mul_mv_q4_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3056,14 +3356,21 @@ kernel void kernel_mul_mv_q5_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3183,14 +3490,21 @@ kernel void kernel_mul_mv_q6_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
@@ -3198,75 +3512,737 @@ kernel void kernel_mul_mv_q6_K_f32(
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
}
-//============================= templates and their specializations =============================
+// ======================= "True" 2-bit
+
+constexpr constant static uint64_t iq2xxs_grid[256] = {
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x08080808082b0808,
+ 0x08080808082b082b, 0x08080808082b2b08, 0x08080808082b2b2b, 0x0808080819080819,
+ 0x0808080819081908, 0x0808080819190808, 0x0808080819192b08, 0x08080808192b0819,
+ 0x08080808192b1908, 0x080808082b080808, 0x080808082b08082b, 0x080808082b082b2b,
+ 0x080808082b2b082b, 0x0808081908080819, 0x0808081908081908, 0x0808081908190808,
+ 0x0808081908191919, 0x0808081919080808, 0x080808192b081908, 0x080808192b192b08,
+ 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b082b082b, 0x0808082b2b08082b,
+ 0x0808190808080819, 0x0808190808081908, 0x0808190808190808, 0x08081908082b0819,
+ 0x08081908082b1908, 0x0808190819080808, 0x080819081908082b, 0x0808190819082b08,
+ 0x08081908192b0808, 0x080819082b080819, 0x080819082b081908, 0x080819082b190808,
+ 0x080819082b2b1908, 0x0808191908080808, 0x080819190808082b, 0x0808191908082b08,
+ 0x08081919082b0808, 0x080819191908192b, 0x08081919192b2b19, 0x080819192b080808,
+ 0x080819192b190819, 0x0808192b08082b19, 0x0808192b08190808, 0x0808192b19080808,
+ 0x0808192b2b081908, 0x0808192b2b2b1908, 0x08082b0808080808, 0x08082b0808081919,
+ 0x08082b0808082b08, 0x08082b0808191908, 0x08082b08082b2b08, 0x08082b0819080819,
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b081919082b, 0x08082b082b082b08,
+ 0x08082b1908081908, 0x08082b1919080808, 0x08082b2b0808082b, 0x08082b2b08191908,
+ 0x0819080808080819, 0x0819080808081908, 0x0819080808190808, 0x08190808082b0819,
+ 0x0819080819080808, 0x08190808192b0808, 0x081908082b081908, 0x081908082b190808,
+ 0x081908082b191919, 0x0819081908080808, 0x0819081908082b08, 0x08190819082b0808,
+ 0x0819081919190808, 0x0819081919192b2b, 0x081908192b080808, 0x0819082b082b1908,
+ 0x0819082b19081919, 0x0819190808080808, 0x0819190808082b08, 0x08191908082b0808,
+ 0x08191908082b1919, 0x0819190819082b19, 0x081919082b080808, 0x0819191908192b08,
+ 0x08191919192b082b, 0x0819192b08080808, 0x0819192b0819192b, 0x08192b0808080819,
+ 0x08192b0808081908, 0x08192b0808190808, 0x08192b0819080808, 0x08192b082b080819,
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b192b2b0808, 0x08192b2b19190819,
+ 0x082b080808080808, 0x082b08080808082b, 0x082b080808082b2b, 0x082b080819081908,
+ 0x082b0808192b0819, 0x082b08082b080808, 0x082b08082b08082b, 0x082b0819082b2b19,
+ 0x082b081919082b08, 0x082b082b08080808, 0x082b082b0808082b, 0x082b190808080819,
+ 0x082b190808081908, 0x082b190808190808, 0x082b190819080808, 0x082b19081919192b,
+ 0x082b191908080808, 0x082b191919080819, 0x082b1919192b1908, 0x082b192b2b190808,
+ 0x082b2b0808082b08, 0x082b2b08082b0808, 0x082b2b082b191908, 0x082b2b2b19081908,
+ 0x1908080808080819, 0x1908080808081908, 0x1908080808190808, 0x1908080808192b08,
+ 0x19080808082b0819, 0x19080808082b1908, 0x1908080819080808, 0x1908080819082b08,
+ 0x190808081919192b, 0x19080808192b0808, 0x190808082b080819, 0x190808082b081908,
+ 0x190808082b190808, 0x1908081908080808, 0x19080819082b0808, 0x19080819192b0819,
+ 0x190808192b080808, 0x190808192b081919, 0x1908082b08080819, 0x1908082b08190808,
+ 0x1908082b19082b08, 0x1908082b1919192b, 0x1908082b192b2b08, 0x1908190808080808,
+ 0x1908190808082b08, 0x19081908082b0808, 0x190819082b080808, 0x190819082b192b19,
+ 0x190819190819082b, 0x19081919082b1908, 0x1908192b08080808, 0x19082b0808080819,
+ 0x19082b0808081908, 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919,
+ 0x19082b1908080808, 0x19082b1919192b08, 0x19082b19192b0819, 0x19082b192b08082b,
+ 0x19082b2b19081919, 0x19082b2b2b190808, 0x1919080808080808, 0x1919080808082b08,
+ 0x1919080808190819, 0x1919080808192b19, 0x19190808082b0808, 0x191908082b080808,
+ 0x191908082b082b08, 0x1919081908081908, 0x191908191908082b, 0x191908192b2b1908,
+ 0x1919082b2b190819, 0x191919082b190808, 0x191919082b19082b, 0x1919191908082b2b,
+ 0x1919192b08080819, 0x1919192b19191908, 0x19192b0808080808, 0x19192b0808190819,
+ 0x19192b0808192b19, 0x19192b08192b1908, 0x19192b1919080808, 0x19192b2b08082b08,
+ 0x192b080808081908, 0x192b080808190808, 0x192b080819080808, 0x192b0808192b2b08,
+ 0x192b081908080808, 0x192b081919191919, 0x192b082b08192b08, 0x192b082b192b0808,
+ 0x192b190808080808, 0x192b190808081919, 0x192b191908190808, 0x192b19190819082b,
+ 0x192b19192b081908, 0x192b2b081908082b, 0x2b08080808080808, 0x2b0808080808082b,
+ 0x2b08080808082b2b, 0x2b08080819080819, 0x2b0808082b08082b, 0x2b08081908081908,
+ 0x2b08081908192b08, 0x2b08081919080808, 0x2b08082b08190819, 0x2b08190808080819,
+ 0x2b08190808081908, 0x2b08190808190808, 0x2b08190808191919, 0x2b08190819080808,
+ 0x2b081908192b0808, 0x2b08191908080808, 0x2b0819191908192b, 0x2b0819192b191908,
+ 0x2b08192b08082b19, 0x2b08192b19080808, 0x2b08192b192b0808, 0x2b082b080808082b,
+ 0x2b082b1908081908, 0x2b082b2b08190819, 0x2b19080808081908, 0x2b19080808190808,
+ 0x2b190808082b1908, 0x2b19080819080808, 0x2b1908082b2b0819, 0x2b1908190819192b,
+ 0x2b1908192b080808, 0x2b19082b19081919, 0x2b19190808080808, 0x2b191908082b082b,
+ 0x2b19190819081908, 0x2b19191919190819, 0x2b192b082b080819, 0x2b192b19082b0808,
+ 0x2b2b08080808082b, 0x2b2b080819190808, 0x2b2b08082b081919, 0x2b2b081908082b19,
+ 0x2b2b082b08080808, 0x2b2b190808192b08, 0x2b2b2b0819190808, 0x2b2b2b1908081908,
+};
-// NOTE: this is not dequantizing - we are simply fitting the template
-template <typename type4x4>
-void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
- float4x4 temp = *(((device float4x4 *)src));
- for (int i = 0; i < 16; i++){
- reg[i/4][i%4] = temp[i/4][i%4];
- }
-}
+constexpr constant static uint64_t iq2xs_grid[512] = {
+ 0x0808080808080808, 0x080808080808082b, 0x0808080808081919, 0x0808080808082b08,
+ 0x0808080808082b2b, 0x0808080808190819, 0x0808080808191908, 0x080808080819192b,
+ 0x0808080808192b19, 0x08080808082b0808, 0x08080808082b082b, 0x08080808082b1919,
+ 0x08080808082b2b08, 0x0808080819080819, 0x0808080819081908, 0x080808081908192b,
+ 0x0808080819082b19, 0x0808080819190808, 0x080808081919082b, 0x0808080819191919,
+ 0x0808080819192b08, 0x08080808192b0819, 0x08080808192b1908, 0x080808082b080808,
+ 0x080808082b08082b, 0x080808082b081919, 0x080808082b082b08, 0x080808082b190819,
+ 0x080808082b191908, 0x080808082b192b19, 0x080808082b2b0808, 0x0808081908080819,
+ 0x0808081908081908, 0x080808190808192b, 0x0808081908082b19, 0x0808081908190808,
+ 0x080808190819082b, 0x0808081908191919, 0x0808081908192b08, 0x0808081908192b2b,
+ 0x08080819082b0819, 0x08080819082b1908, 0x0808081919080808, 0x080808191908082b,
+ 0x0808081919081919, 0x0808081919082b08, 0x0808081919190819, 0x0808081919191908,
+ 0x08080819192b0808, 0x08080819192b2b08, 0x080808192b080819, 0x080808192b081908,
+ 0x080808192b190808, 0x0808082b08080808, 0x0808082b0808082b, 0x0808082b08081919,
+ 0x0808082b08082b08, 0x0808082b08190819, 0x0808082b08191908, 0x0808082b082b0808,
+ 0x0808082b19080819, 0x0808082b19081908, 0x0808082b19190808, 0x0808082b19191919,
+ 0x0808082b2b080808, 0x0808082b2b082b2b, 0x0808190808080819, 0x0808190808081908,
+ 0x080819080808192b, 0x0808190808082b19, 0x0808190808190808, 0x080819080819082b,
+ 0x0808190808191919, 0x0808190808192b08, 0x08081908082b0819, 0x08081908082b1908,
+ 0x0808190819080808, 0x080819081908082b, 0x0808190819081919, 0x0808190819082b08,
+ 0x0808190819190819, 0x0808190819191908, 0x080819081919192b, 0x08081908192b0808,
+ 0x080819082b080819, 0x080819082b081908, 0x080819082b190808, 0x0808191908080808,
+ 0x080819190808082b, 0x0808191908081919, 0x0808191908082b08, 0x0808191908190819,
+ 0x0808191908191908, 0x08081919082b0808, 0x0808191919080819, 0x0808191919081908,
+ 0x0808191919190808, 0x08081919192b0819, 0x080819192b080808, 0x0808192b08080819,
+ 0x0808192b08081908, 0x0808192b08190808, 0x0808192b082b192b, 0x0808192b19080808,
+ 0x0808192b1908082b, 0x0808192b2b081908, 0x08082b0808080808, 0x08082b080808082b,
+ 0x08082b0808081919, 0x08082b0808082b08, 0x08082b0808082b2b, 0x08082b0808190819,
+ 0x08082b0808191908, 0x08082b08082b0808, 0x08082b08082b1919, 0x08082b0819080819,
+ 0x08082b0819081908, 0x08082b0819190808, 0x08082b0819192b08, 0x08082b082b080808,
+ 0x08082b082b2b0808, 0x08082b082b2b2b2b, 0x08082b1908080819, 0x08082b1908081908,
+ 0x08082b1908190808, 0x08082b1919080808, 0x08082b192b080819, 0x08082b192b082b19,
+ 0x08082b2b08080808, 0x08082b2b082b0808, 0x08082b2b082b2b08, 0x08082b2b2b19192b,
+ 0x08082b2b2b2b0808, 0x0819080808080819, 0x0819080808081908, 0x081908080808192b,
+ 0x0819080808082b19, 0x0819080808190808, 0x081908080819082b, 0x0819080808191919,
+ 0x0819080808192b08, 0x08190808082b0819, 0x08190808082b1908, 0x0819080819080808,
+ 0x081908081908082b, 0x0819080819081919, 0x0819080819082b08, 0x0819080819190819,
+ 0x0819080819191908, 0x08190808192b0808, 0x08190808192b2b2b, 0x081908082b080819,
+ 0x081908082b081908, 0x081908082b190808, 0x0819081908080808, 0x081908190808082b,
+ 0x0819081908081919, 0x0819081908082b08, 0x0819081908190819, 0x0819081908191908,
+ 0x08190819082b0808, 0x0819081919080819, 0x0819081919081908, 0x0819081919190808,
+ 0x081908192b080808, 0x081908192b191908, 0x081908192b19192b, 0x0819082b08080819,
+ 0x0819082b08081908, 0x0819082b0808192b, 0x0819082b08190808, 0x0819082b19080808,
+ 0x0819082b192b0808, 0x0819190808080808, 0x081919080808082b, 0x0819190808081919,
+ 0x0819190808082b08, 0x0819190808190819, 0x0819190808191908, 0x08191908082b0808,
+ 0x0819190819080819, 0x0819190819081908, 0x0819190819082b19, 0x0819190819190808,
+ 0x08191908192b1908, 0x081919082b080808, 0x0819191908080819, 0x0819191908081908,
+ 0x0819191908190808, 0x0819191919080808, 0x0819192b08080808, 0x0819192b08191908,
+ 0x0819192b19082b19, 0x08192b0808080819, 0x08192b0808081908, 0x08192b0808190808,
+ 0x08192b080819082b, 0x08192b0819080808, 0x08192b0819191908, 0x08192b082b08192b,
+ 0x08192b1908080808, 0x08192b1908081919, 0x08192b19192b192b, 0x08192b2b19190819,
+ 0x08192b2b2b2b2b19, 0x082b080808080808, 0x082b08080808082b, 0x082b080808081919,
+ 0x082b080808082b08, 0x082b080808082b2b, 0x082b080808190819, 0x082b080808191908,
+ 0x082b0808082b0808, 0x082b080819080819, 0x082b080819081908, 0x082b080819190808,
+ 0x082b08082b080808, 0x082b08082b2b0808, 0x082b081908080819, 0x082b081908081908,
+ 0x082b081908190808, 0x082b081919080808, 0x082b081919082b08, 0x082b0819192b1919,
+ 0x082b082b08080808, 0x082b082b082b082b, 0x082b082b2b080808, 0x082b082b2b2b2b08,
+ 0x082b190808080819, 0x082b190808081908, 0x082b190808190808, 0x082b1908082b2b19,
+ 0x082b190819080808, 0x082b191908080808, 0x082b191919080819, 0x082b19191919082b,
+ 0x082b19192b192b19, 0x082b192b08080819, 0x082b192b08192b2b, 0x082b192b2b2b192b,
+ 0x082b2b0808080808, 0x082b2b0808082b08, 0x082b2b0808082b2b, 0x082b2b08082b0808,
+ 0x082b2b0819191919, 0x082b2b082b082b08, 0x082b2b082b2b082b, 0x082b2b19192b2b08,
+ 0x082b2b192b190808, 0x082b2b2b08082b08, 0x082b2b2b082b0808, 0x082b2b2b2b08082b,
+ 0x082b2b2b2b082b08, 0x082b2b2b2b082b2b, 0x1908080808080819, 0x1908080808081908,
+ 0x190808080808192b, 0x1908080808082b19, 0x1908080808190808, 0x190808080819082b,
+ 0x1908080808191919, 0x1908080808192b08, 0x19080808082b0819, 0x19080808082b1908,
+ 0x1908080819080808, 0x190808081908082b, 0x1908080819081919, 0x1908080819082b08,
+ 0x1908080819082b2b, 0x1908080819190819, 0x1908080819191908, 0x19080808192b0808,
+ 0x19080808192b1919, 0x190808082b080819, 0x190808082b081908, 0x190808082b190808,
+ 0x1908081908080808, 0x190808190808082b, 0x1908081908081919, 0x1908081908082b08,
+ 0x1908081908190819, 0x1908081908191908, 0x19080819082b0808, 0x1908081919080819,
+ 0x1908081919081908, 0x1908081919190808, 0x190808192b080808, 0x190808192b081919,
+ 0x190808192b2b082b, 0x1908082b08080819, 0x1908082b08081908, 0x1908082b08190808,
+ 0x1908082b0819082b, 0x1908082b082b2b19, 0x1908082b19080808, 0x1908190808080808,
+ 0x190819080808082b, 0x1908190808081919, 0x1908190808082b08, 0x1908190808190819,
+ 0x1908190808191908, 0x1908190808192b19, 0x19081908082b0808, 0x1908190819080819,
+ 0x1908190819081908, 0x1908190819190808, 0x190819082b080808, 0x190819082b191908,
+ 0x1908191908080819, 0x1908191908081908, 0x1908191908190808, 0x19081919082b1908,
+ 0x1908191919080808, 0x190819192b192b2b, 0x1908192b08080808, 0x1908192b08082b2b,
+ 0x1908192b19081908, 0x1908192b19190808, 0x19082b0808080819, 0x19082b0808081908,
+ 0x19082b0808190808, 0x19082b0819080808, 0x19082b0819081919, 0x19082b0819191908,
+ 0x19082b08192b082b, 0x19082b1908080808, 0x19082b1908190819, 0x19082b1919081908,
+ 0x19082b1919190808, 0x19082b19192b2b19, 0x19082b2b08081908, 0x1919080808080808,
+ 0x191908080808082b, 0x1919080808081919, 0x1919080808082b08, 0x1919080808190819,
+ 0x1919080808191908, 0x19190808082b0808, 0x19190808082b2b08, 0x1919080819080819,
+ 0x1919080819081908, 0x1919080819190808, 0x191908082b080808, 0x1919081908080819,
+ 0x1919081908081908, 0x1919081908190808, 0x1919081908191919, 0x1919081919080808,
+ 0x191908191908082b, 0x1919082b08080808, 0x1919082b19081908, 0x1919082b2b2b2b2b,
+ 0x1919190808080819, 0x1919190808081908, 0x1919190808190808, 0x19191908082b0819,
+ 0x1919190819080808, 0x19191908192b0808, 0x191919082b080819, 0x191919082b2b0819,
+ 0x1919191908080808, 0x1919191908082b08, 0x191919192b080808, 0x191919192b082b08,
+ 0x1919192b082b0819, 0x1919192b192b2b08, 0x1919192b2b2b0819, 0x19192b0808080808,
+ 0x19192b0808191908, 0x19192b0819080819, 0x19192b0819190808, 0x19192b082b192b19,
+ 0x19192b1908192b2b, 0x19192b1919080808, 0x19192b191908082b, 0x19192b2b2b081919,
+ 0x192b080808080819, 0x192b080808081908, 0x192b080808190808, 0x192b080819080808,
+ 0x192b080819191908, 0x192b0808192b082b, 0x192b08082b08192b, 0x192b08082b2b2b19,
+ 0x192b081908080808, 0x192b082b082b1908, 0x192b082b19082b2b, 0x192b082b2b19082b,
+ 0x192b190808080808, 0x192b19080819192b, 0x192b191908190808, 0x192b191919080808,
+ 0x192b191919081919, 0x192b19192b2b1908, 0x192b2b0808080819, 0x192b2b08192b2b2b,
+ 0x192b2b19082b1919, 0x192b2b2b0808192b, 0x192b2b2b19191908, 0x192b2b2b192b082b,
+ 0x2b08080808080808, 0x2b0808080808082b, 0x2b08080808081919, 0x2b08080808082b08,
+ 0x2b08080808190819, 0x2b08080808191908, 0x2b080808082b0808, 0x2b080808082b2b2b,
+ 0x2b08080819080819, 0x2b08080819081908, 0x2b08080819190808, 0x2b0808082b080808,
+ 0x2b0808082b08082b, 0x2b0808082b2b2b08, 0x2b0808082b2b2b2b, 0x2b08081908080819,
+ 0x2b08081908081908, 0x2b0808190808192b, 0x2b08081908190808, 0x2b08081919080808,
+ 0x2b08081919190819, 0x2b08081919192b19, 0x2b08082b08080808, 0x2b08082b082b0808,
+ 0x2b08082b2b080808, 0x2b08082b2b08082b, 0x2b08082b2b2b0808, 0x2b08082b2b2b2b08,
+ 0x2b08190808080819, 0x2b08190808081908, 0x2b08190808190808, 0x2b0819080819082b,
+ 0x2b08190808191919, 0x2b08190819080808, 0x2b081908192b0808, 0x2b0819082b082b19,
+ 0x2b08191908080808, 0x2b08191919081908, 0x2b0819192b2b1919, 0x2b08192b08192b08,
+ 0x2b08192b192b2b2b, 0x2b082b0808080808, 0x2b082b0808082b08, 0x2b082b08082b1919,
+ 0x2b082b0819192b2b, 0x2b082b082b080808, 0x2b082b082b08082b, 0x2b082b082b2b2b08,
+ 0x2b082b190808192b, 0x2b082b2b082b082b, 0x2b082b2b2b080808, 0x2b082b2b2b082b08,
+ 0x2b082b2b2b19192b, 0x2b082b2b2b2b2b08, 0x2b19080808080819, 0x2b19080808081908,
+ 0x2b19080808190808, 0x2b19080819080808, 0x2b1908081919192b, 0x2b1908082b081908,
+ 0x2b19081908080808, 0x2b190819082b082b, 0x2b190819192b1908, 0x2b19082b1919192b,
+ 0x2b19082b2b082b19, 0x2b19190808080808, 0x2b19190808081919, 0x2b19190819081908,
+ 0x2b19190819190808, 0x2b19190819192b08, 0x2b191919082b2b19, 0x2b1919192b190808,
+ 0x2b1919192b19082b, 0x2b19192b19080819, 0x2b192b0819190819, 0x2b192b082b2b192b,
+ 0x2b192b1919082b19, 0x2b192b2b08191919, 0x2b192b2b192b0808, 0x2b2b080808080808,
+ 0x2b2b08080808082b, 0x2b2b080808082b08, 0x2b2b080808082b2b, 0x2b2b0808082b0808,
+ 0x2b2b0808082b2b2b, 0x2b2b08082b2b0808, 0x2b2b081919190819, 0x2b2b081919192b19,
+ 0x2b2b08192b2b192b, 0x2b2b082b08080808, 0x2b2b082b0808082b, 0x2b2b082b08082b08,
+ 0x2b2b082b082b2b2b, 0x2b2b082b2b080808, 0x2b2b082b2b2b0808, 0x2b2b190819080808,
+ 0x2b2b19082b191919, 0x2b2b192b192b1919, 0x2b2b192b2b192b08, 0x2b2b2b0808082b2b,
+ 0x2b2b2b08082b0808, 0x2b2b2b08082b082b, 0x2b2b2b08082b2b08, 0x2b2b2b082b2b0808,
+ 0x2b2b2b082b2b2b08, 0x2b2b2b1908081908, 0x2b2b2b192b081908, 0x2b2b2b192b08192b,
+ 0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
+};
-template <typename type4x4>
-void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
- half4x4 temp = *(((device half4x4 *)src));
- for (int i = 0; i < 16; i++){
- reg[i/4][i%4] = temp[i/4][i%4];
- }
-}
+constexpr constant static uint32_t iq3xxs_grid[256] = {
+ 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414,
+ 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14,
+ 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404,
+ 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e,
+ 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c,
+ 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c,
+ 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34,
+ 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c,
+ 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c,
+ 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04,
+ 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c,
+ 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414,
+ 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434,
+ 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c,
+ 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e,
+ 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24,
+ 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24,
+ 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c,
+ 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c,
+ 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14,
+ 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414,
+ 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e,
+ 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404,
+ 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c,
+ 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c,
+ 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14,
+ 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c,
+ 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c,
+ 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14,
+ 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14,
+ 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c,
+ 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
+};
-template <typename type4x4>
-void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
- const float d1 = il ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float md = -8.h * xb->d;
- const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
- for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
- }
-}
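+// Sign patterns shared by the iq2/iq3 quants: entry i keeps its 7 explicit
+// sign bits in bits 0..6, and bit 7 is set so that the byte has even parity,
+// which supplies the implicit 8th sign of each group of 8 values.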
+constexpr constant static uint8_t ksigns_iq2xs[128] = {
+ 0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
+ 144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
+ 160, 33, 34, 163, 36, 165, 166, 39, 40, 169, 170, 43, 172, 45, 46, 175,
+ 48, 177, 178, 51, 180, 53, 54, 183, 184, 57, 58, 187, 60, 189, 190, 63,
+ 192, 65, 66, 195, 68, 197, 198, 71, 72, 201, 202, 75, 204, 77, 78, 207,
+ 80, 209, 210, 83, 212, 85, 86, 215, 216, 89, 90, 219, 92, 221, 222, 95,
+ 96, 225, 226, 99, 228, 101, 102, 231, 232, 105, 106, 235, 108, 237, 238, 111,
+ 240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
+};
-template <typename type4x4>
-void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
- const float d1 = il ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float m = xb->m;
- const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
+constexpr constant static uint8_t kmask_iq2xs[8] = {1, 2, 4, 8, 16, 32, 64, 128};
- for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
- reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
- }
-}
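+// Matrix-vector multiply for IQ2_XXS blocks: each simdgroup of the
+// threadgroup accumulates N_DST rows of the result, after staging the grid
+// and sign lookup tables in threadgroup memory.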
+void kernel_mul_mv_iq2_xxs_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
-template <typename type4x4>
-void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
- const float d = xb->d;
- const float md = -16.h * xb->d;
- const ushort mask = il ? 0x00F0 : 0x000F;
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
- const uint32_t qh = *((device const uint32_t *)xb->qh);
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
- const int x_mv = il ? 4 : 0;
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
- const int gh_mv = il ? 12 : 0;
- const int gh_bk = il ? 0 : 4;
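+ // r2/r3 broadcast src0 across dims 2/3 (e.g. grouped-query attention,
+ // where several src1 heads share a single src0 head)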
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
- for (int i = 0; i < 8; i++) {
- // extract the 5-th bits for x0 and x1
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
- // combine the 4-bits from qs with the 5th bit
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
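+ // cooperatively stage the 256-entry grid and the 128-entry sign table in
+ // threadgroup memory; each thread copies a small slice of each table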
+ {
+ int nval = 4;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+#if QK_K == 256
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq2_xxs * xr = x + ibl;
+ device const uint16_t * q2 = xr->qs + 4 * ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
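+ // q2[0..1] hold four 8-bit grid indices for this group of 32 quants;
+ // q2[2..3] pack four 7-bit sign indices plus a 4-bit scale in the top nibble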
+ device const uint8_t * aux8 = (device const uint8_t *)q2;
+ const uint32_t aux32 = q2[2] | (q2[3] << 16);
+ const float d = db * (0.5f + (aux32 >> 28));
+
+ float sum = 0;
+ for (int l = 0; l < 4; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+ for (int j = 0; j < 8; ++j) {
+ sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d * sum;
+
+ dh += nb*sizeof(block_iq2_xxs)/2;
+ q2 += nb*sizeof(block_iq2_xxs)/2;
+ }
+
+ y4 += 32 * 32;
+ }
+#else
+ // TODO
+#endif
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_iq2_xxs_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
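+// Same structure as the IQ2_XXS kernel above; IQ2_XS stores an explicit
+// 9-bit grid index and 7-bit sign index per group of 8, plus a 4-bit scale
+// per group of 16.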
+void kernel_mul_mv_iq2_xs_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
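+ // stage the 512-entry iq2_xs grid (8 uint64 entries per thread) and the
+ // 128-entry sign table into threadgroup memory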
+ {
+ int nval = 8;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+#if QK_K == 256
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq2_xs * xr = x + ibl;
+ device const uint16_t * q2 = xr->qs + 4 * ib;
+ device const uint8_t * sc = xr->scales + ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
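+ // one scale byte per group of 32 quants: the low nibble scales the first
+ // 16 quants, the high nibble the second 16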
+ const uint8_t ls1 = sc[0] & 0xf;
+ const uint8_t ls2 = sc[0] >> 4;
+ const float d1 = db * (0.5f + ls1);
+ const float d2 = db * (0.5f + ls2);
+
+ float sum1 = 0, sum2 = 0;
+ for (int l = 0; l < 2; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
+ for (int j = 0; j < 8; ++j) {
+ sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ for (int l = 2; l < 4; ++l) {
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
+ const uint8_t signs = shared_signs[(q2[l] >> 9)];
+ for (int j = 0; j < 8; ++j) {
+ sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d1 * sum1 + d2 * sum2;
+
+ dh += nb*sizeof(block_iq2_xs)/2;
+ q2 += nb*sizeof(block_iq2_xs)/2;
+ sc += nb*sizeof(block_iq2_xs);
+ }
+
+ y4 += 32 * 32;
+ }
+#else
+ // TODO
+#endif
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_xs_f32")]]
+kernel void kernel_mul_mv_iq2_xs_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
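+// IQ3_XXS variant: 8-bit indices into a 256-entry grid of packed uint32
+// values, with the scale-and-sign words stored after the grid indices.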
+void kernel_mul_mv_iq3_xxs_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const int nb = ne00/QK_K;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
+
+ const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int ib_row = first_row * nb;
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+
+ device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
+
+ float yl[32];
+ float sumf[N_DST]={0.f}, all_sum;
+
+ const int nb32 = nb * (QK_K / 32);
+
+ threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+ threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
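+ // stage the 256-entry iq3_xxs grid (uint32 entries) and the 128-entry
+ // sign table into threadgroup memory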
+ {
+ int nval = 4;
+ int pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
+ nval = 2;
+ pos = (32*sgitg + tiisg)*nval;
+ for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+
+#if QK_K == 256
+ const int ix = tiisg;
+
+ device const float * y4 = y + 32 * ix;
+
+ for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
+
+ for (int i = 0; i < 32; ++i) {
+ yl[i] = y4[i];
+ }
+
+ const int ibl = ib32 / (QK_K / 32);
+ const int ib = ib32 % (QK_K / 32);
+
+ device const block_iq3_xxs * xr = x + ibl;
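+ // the first QK_K/4 bytes of qs hold the 8-bit grid indices; the packed
+ // scale-and-sign words follow immediately after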
+ device const uint8_t * q3 = xr->qs + 8 * ib;
+ device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
+ device const half * dh = &xr->d;
+
+ for (int row = 0; row < N_DST; row++) {
+
+ const float db = dh[0];
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float d = db * (0.5f + (aux32 >> 28));
+
+ float2 sum = {0};
+ for (int l = 0; l < 4; ++l) {
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
+ const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+ for (int j = 0; j < 4; ++j) {
+ sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
+ sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
+ }
+ }
+ sumf[row] += d * (sum[0] + sum[1]);
+
+ dh += nb*sizeof(block_iq3_xxs)/2;
+ q3 += nb*sizeof(block_iq3_xxs);
+ gas += nb*sizeof(block_iq3_xxs)/2;
+ }
+
+ y4 += 32 * 32;
+ }
+#else
+ // TODO
+#endif
+
+ for (int row = 0; row < N_DST; ++row) {
+ all_sum = simd_sum(sumf[row]);
+ if (tiisg == 0) {
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
+ }
+ }
+}
+
+[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_iq3_xxs_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+}
+
+
+//============================= templates and their specializations =============================
+
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+ float4x4 temp = *(((device float4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+ half4x4 temp = *(((device half4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.h * xb->d;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float m = xb->m;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
+
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
@@ -3308,17 +4284,17 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d;
- for (int i=0;i<16;i++) {
+ for (int i = 0; i < 16; i++) {
reg[i/4][i%4] = (qs[i + 16*il] * d);
}
}
 template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
- const half d = xb->d;
- const half min = xb->dmin;
+ const float d = xb->d;
+ const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
- half dl, ml;
+ float dl, ml;
uint8_t sc = xb->scales[il];
#if QK_K == 256
@@ -3350,8 +4326,8 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
: (scale_2&kmask2) | ((scale_1&kmask1) << 4);
- half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
- const half ml = 4.h * dl;
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+ const float ml = 4.f * dl;
il = (il/2) & 3;
const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
@@ -3388,10 +4364,10 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
q = q + (il/4) * 32 + 16 * (il&1);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const half d = il < 2 ? xb->d : xb->d / 16.h;
- const half min = xb->dmin;
- const half dl = d * sc[0];
- const half ml = min * sc[1];
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
#else
q = q + 16 * (il&1);
device const uint8_t * s = xb->scales;
@@ -3418,13 +4394,13 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
uint8_t ul = 1 << (il/2);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const half d = il < 2 ? xb->d : xb->d / 16.h;
- const half min = xb->dmin;
- const half dl = d * sc[0];
- const half ml = min * sc[1];
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
- const ushort mask = il<2 ? 0x0F : 0xF0;
- const half qh_val = il<2 ? 16.h : 256.h;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
@@ -3451,17 +4427,17 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
#if QK_K == 256
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
qh = qh + 32*(il/8) + 16*(il&1);
- half sc = scales[(il%2) + 2 * ((il/2))];
+ float sc = scales[(il%2) + 2 * ((il/2))];
il = (il/2) & 3;
#else
ql = ql + 16 * (il&1);
- half sc = scales[il];
+ float sc = scales[il];
#endif
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
- const half coef = il>1 ? 1.f/16.h : 1.h;
- const half ml = d_all * sc * 32.h;
- const half dl = d_all * sc * coef;
+ const float coef = il>1 ? 1.f/16.f : 1.f;
+ const float ml = d_all * sc * 32.f;
+ const float dl = d_all * sc * coef;
for (int i = 0; i < 16; ++i) {
const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
: ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
@@ -3469,6 +4445,79 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
}
}
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+ device const uint16_t * q2 = xb->qs + 4*ib32;
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint16_t * q2 = xb->qs + 4*ib32;
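+ // each uint16 in qs packs a 9-bit grid index (low bits) and a 7-bit sign index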
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * q3 = xb->qs + 8*ib32;
+ device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+ uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+ reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+ }
+ grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+ grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+ signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+ for (int i = 0; i < 4; ++i) {
+ reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+ reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+ }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
device const void * src0,
@@ -3559,6 +4608,35 @@ kernel void kernel_get_rows_f16(
}
}
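+// gather rows of an int32 tensor: the row index for each (i10, i11) is read
+// from src1 and ne00 int32 values are copied from that row of src0 into dst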
+kernel void kernel_get_rows_i32(
+ device const void * src0,
+ device const char * src1,
+ device int32_t * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ }
+}
+
+
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32
@@ -3577,12 +4655,12 @@ void kernel_mul_mm_impl(device const uchar * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
- constant int64_t & nb01,
- constant int64_t & nb02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
constant int64_t & ne12,
- constant int64_t & nb10,
- constant int64_t & nb11,
- constant int64_t & nb12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
@@ -3625,7 +4703,144 @@ void kernel_mul_mm_impl(device const uchar * src0,
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
+ nb12 * im
- + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+ // load data and store to threadgroup memory
+ half4x4 temp_a;
+ dequantize_func(x, il, temp_a);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ #pragma unroll(16)
+ for (int i = 0; i < 16; i++) {
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+ }
+
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ y += BLOCK_SIZE_K;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // load matrices from threadgroup memory and conduct outer products
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+ #pragma unroll(4)
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+ #pragma unroll(4)
+ for (int i = 0; i < 4; i++) {
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ #pragma unroll(2)
+ for (int i = 0; i < 2; i++) {
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+ }
+
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+ #pragma unroll(8)
+ for (int i = 0; i < 8; i++){
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+ }
+ }
+ }
+
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
+ device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+ }
+ } else {
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+ if (sgitg == 0) {
+ for (int i = 0; i < n_rows; i++) {
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+ }
+ }
+ }
+ }
+}
+
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+ device const uchar * src0,
+ device const uchar * src1,
+ thread short * src1ids,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ int64_t ne1,
+ constant uint & r2,
+ constant uint & r3,
+ threadgroup uchar * shared_memory,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = (threadgroup half *)(shared_memory);
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
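+ // threadgroup memory layout: the first 4 KiB holds the dequantized A tile
+ // as half, the remainder holds the B tile as float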
+
+ const uint r0 = tgpig.y;
+ const uint r1 = tgpig.x;
+ const uint im = tgpig.z;
+
+ if (r1 * BLOCK_SIZE_N >= ne1) return;
+
+ // if this block is of 64x32 shape or smaller
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+ // a thread shouldn't load data outside of the matrix
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+ simdgroup_half8x8 ma[4];
+ simdgroup_float8x8 mb[2];
+ simdgroup_float8x8 c_res[8];
+ for (int i = 0; i < 8; i++){
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+ }
+
+ short il = (tiitg % THREAD_PER_ROW);
+
+ const uint i12 = im%ne12;
+ const uint i13 = im/ne12;
+
+ uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+ ushort offset1 = il/nl;
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ device const float * y = (device const float *)(src1
+ + nb12 * im
+ + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+ + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
@@ -3634,7 +4849,6 @@ void kernel_mul_mm_impl(device const uchar * src0,
dequantize_func(x, il, temp_a);
threadgroup_barrier(mem_flags::mem_threadgroup);
- #pragma unroll(16)
for (int i = 0; i < 16; i++) {
*(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
@@ -3653,14 +4867,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
- #pragma unroll(4)
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
- #pragma unroll(4)
for (int i = 0; i < 4; i++) {
simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
}
simdgroup_barrier(mem_flags::mem_none);
- #pragma unroll(2)
for (int i = 0; i < 2; i++) {
simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
}
@@ -3668,21 +4879,13 @@ void kernel_mul_mm_impl(device const uchar * src0,
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
- #pragma unroll(8)
for (int i = 0; i < 8; i++){
simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
}
}
}
- if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
- device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
- for (int i = 0; i < 8; i++) {
- simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
- }
- } else {
- // block is smaller than 64x32, we should avoid writing data outside of the matrix
+ {
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
@@ -3692,11 +4895,11 @@ void kernel_mul_mm_impl(device const uchar * src0,
threadgroup_barrier(mem_flags::mem_threadgroup);
- device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+ device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
if (sgitg == 0) {
for (int i = 0; i < n_rows; i++) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
- *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+ *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
}
}
}
@@ -3709,12 +4912,12 @@ kernel void kernel_mul_mm(device const uchar * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
- constant int64_t & nb01,
- constant int64_t & nb02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
constant int64_t & ne12,
- constant int64_t & nb10,
- constant int64_t & nb11,
- constant int64_t & nb12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
@@ -3749,20 +4952,20 @@ template
- src0[id],
- src1 + bid*nb11,
- (device float *) (dst + bid*nb1),
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
+ src1ids[_ne1++] = i1;
+ }
+ }
+
+ kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+ src0s[id],
+ src1,
+ src1ids,
+ dst,
ne00,
ne02,
nb01,
@@ -3799,7 +5012,7 @@ kernel void kernel_mul_mm_id(
nb11,
nb12,
ne0,
- ne1,
+ _ne1,
r2,
r3,
shared_memory,
@@ -3844,6 +5057,9 @@ template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_iq2_xs")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows;
//
// matrix-matrix multiplication
@@ -3855,12 +5071,12 @@ typedef void (mat_mm_t)(
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
- constant int64_t & nb01,
- constant int64_t & nb02,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
constant int64_t & ne12,
- constant int64_t & nb10,
- constant int64_t & nb11,
- constant int64_t & nb12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint & r2,
@@ -3880,6 +5096,9 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq2_xs_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm;
//
// indirect matrix-matrix multiplication
@@ -3888,20 +5107,20 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq2_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq2_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
+template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id;
//
// matrix-vector multiplication
@@ -3937,8 +5159,8 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
kernel void kernel_mul_mv_id_f32_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -3954,7 +5176,7 @@ kernel void kernel_mul_mv_id_f32_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -3981,7 +5203,7 @@ kernel void kernel_mul_mv_id_f32_f32(
kernel_mul_mv_f32_f32_impl(
src0[id],
src1 + bid*nb11,
- (device float *) (dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4006,8 +5228,8 @@ kernel void kernel_mul_mv_id_f32_f32(
kernel void kernel_mul_mv_id_f16_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4023,7 +5245,7 @@ kernel void kernel_mul_mv_id_f16_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4050,7 +5272,7 @@ kernel void kernel_mul_mv_id_f16_f32(
kernel_mul_mv_f16_f32_impl(
src0[id],
src1 + bid*nb11,
- (device float *) (dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4075,8 +5297,8 @@ kernel void kernel_mul_mv_id_f16_f32(
kernel void kernel_mul_mv_id_q8_0_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4092,7 +5314,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4119,7 +5341,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
kernel_mul_mv_q8_0_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4138,8 +5360,8 @@ kernel void kernel_mul_mv_id_q8_0_f32(
kernel void kernel_mul_mv_id_q4_0_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4155,7 +5377,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4182,7 +5404,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
mul_vec_q_n_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4201,8 +5423,8 @@ kernel void kernel_mul_mv_id_q4_0_f32(
kernel void kernel_mul_mv_id_q4_1_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4218,7 +5440,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4245,7 +5467,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
mul_vec_q_n_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4264,8 +5486,8 @@ kernel void kernel_mul_mv_id_q4_1_f32(
kernel void kernel_mul_mv_id_q5_0_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4281,7 +5503,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4308,7 +5530,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
mul_vec_q_n_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4327,8 +5549,8 @@ kernel void kernel_mul_mv_id_q5_0_f32(
kernel void kernel_mul_mv_id_q5_1_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4344,7 +5566,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4371,7 +5593,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
mul_vec_q_n_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4390,8 +5612,8 @@ kernel void kernel_mul_mv_id_q5_1_f32(
kernel void kernel_mul_mv_id_q2_K_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4407,7 +5629,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4434,7 +5656,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
kernel_mul_mv_q2_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4453,8 +5675,8 @@ kernel void kernel_mul_mv_id_q2_K_f32(
kernel void kernel_mul_mv_id_q3_K_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4470,7 +5692,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4497,7 +5719,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
kernel_mul_mv_q3_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4516,8 +5738,8 @@ kernel void kernel_mul_mv_id_q3_K_f32(
kernel void kernel_mul_mv_id_q4_K_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4533,7 +5755,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4560,7 +5782,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
kernel_mul_mv_q4_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4579,8 +5801,8 @@ kernel void kernel_mul_mv_id_q4_K_f32(
kernel void kernel_mul_mv_id_q5_K_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4596,7 +5818,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4623,7 +5845,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
kernel_mul_mv_q5_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4642,8 +5864,8 @@ kernel void kernel_mul_mv_id_q5_K_f32(
kernel void kernel_mul_mv_id_q6_K_f32(
device const char * ids,
device const char * src1,
- device uchar * dst,
- constant int64_t & nbi1,
+ device float * dst,
+ constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
@@ -4659,7 +5881,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant int64_t & nb1,
+ constant uint64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
@@ -4686,7 +5908,201 @@ kernel void kernel_mul_mv_id_q6_K_f32(
kernel_mul_mv_q6_K_f32_impl(
src0[id],
(device const float *) (src1 + bid*nb11),
- (device float *) ( dst + bid*nb1),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
+kernel void kernel_mul_mv_id_iq2_xxs_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};  // the eight candidate src0 matrices
+
+    const int64_t bid = tgpig.z/(ne12*ne13);  // which batch row this threadgroup handles
+
+    tgpig.z = tgpig.z%(ne12*ne13);  // fold the row index back out of the z coordinate
+
+    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];  // matrix selected by ids for this row
+
+ kernel_mul_mv_iq2_xxs_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ shared_values,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
+kernel void kernel_mul_mv_id_iq2_xs_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_iq2_xs_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ shared_values,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
+kernel void kernel_mul_mv_id_iq3_xxs_f32(
+ device const char * ids,
+ device const char * src1,
+ device float * dst,
+ constant uint64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ threadgroup int8_t * shared_values [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_iq3_xxs_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ dst + bid*ne0,
ne00,
ne01,
ne02,
@@ -4696,6 +6112,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
ne1,
r2,
r3,
+ shared_values,
tgpig,
tiisg,
sgitg);
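
Across all of the `kernel_mul_mv_id_*_f32` entry points above, `dst` is retyped from `device uchar *` to `device float *`, so the per-row output offset switches from byte arithmetic (`dst + bid*nb1`, followed by a cast) to element arithmetic (`dst + bid*ne0`); the stride arguments `nbi1`/`nb1` likewise become `uint64_t`, matching the other byte strides (`nb10`, `nb11`, `nb12`). A minimal C sketch of why the two addressing forms land on the same address, assuming the contiguous output layout the kernels rely on, where `nb1 == ne0 * sizeof(float)` (names here are illustrative, not part of the patch):

```c
#include <assert.h>
#include <stdint.h>

int main(void) {
    const int64_t  ne0 = 8;                    /* elements per output row              */
    const uint64_t nb1 = ne0 * sizeof(float);  /* byte stride of one row (contiguous)  */
    float dst[4 * 8];                          /* 4 batch rows (bid) x ne0 floats each */

    for (int64_t bid = 0; bid < 4; bid++) {
        /* old form: byte offset on a uchar pointer, then cast back to float* */
        float *by_bytes    = (float *) ((unsigned char *) dst + bid * nb1);
        /* new form: element offset directly on a float pointer */
        float *by_elements = dst + bid * ne0;
        assert(by_bytes == by_elements);       /* identical while nb1 == ne0*sizeof(float) */
    }
    return 0;
}
```

Dropping the per-call cast keeps each wrapper a one-line call and expresses the offset in the same units (`ne0` elements) that the `*_impl` functions index with.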
diff --git a/LLama/runtimes/deps/osx-arm64/libllama.dylib b/LLama/runtimes/deps/osx-arm64/libllama.dylib
index 712a0be4..853998a7 100644
Binary files a/LLama/runtimes/deps/osx-arm64/libllama.dylib and b/LLama/runtimes/deps/osx-arm64/libllama.dylib differ
diff --git a/LLama/runtimes/deps/osx-x64/libllama.dylib b/LLama/runtimes/deps/osx-x64/libllama.dylib
index c976111a..208bfe84 100644
Binary files a/LLama/runtimes/deps/osx-x64/libllama.dylib and b/LLama/runtimes/deps/osx-x64/libllama.dylib differ
diff --git a/README.md b/README.md
index a73fb3c7..18dd521d 100644
--- a/README.md
+++ b/README.md
@@ -222,6 +222,7 @@ If you want to compile llama.cpp yourself you **must** use the exact commit ID l
| v0.7.0, v0.8.0 | [Thespis-13B](https://huggingface.co/TheBloke/Thespis-13B-v0.5-GGUF/tree/main?not-for-all-audiences=true), [LLaMA2-7B](https://huggingface.co/TheBloke/llama-2-7B-Guanaco-QLoRA-GGUF) | [`207b519`](https://github.com/ggerganov/llama.cpp/commit/207b51900e15cc7f89763a3bb1c565fe11cbb45d) |
| v0.8.1 | | [`e937066`](https://github.com/ggerganov/llama.cpp/commit/e937066420b79a757bf80e9836eb12b88420a218) |
| v0.9.0, v0.9.1 | [Mixtral-8x7B](https://huggingface.co/TheBloke/Mixtral-8x7B-v0.1-GGUF) | [`9fb13f9`](https://github.com/ggerganov/llama.cpp/blob/9fb13f95840c722ad419f390dc8a9c86080a3700) |
+| v0.10.0 | [Phi2](https://huggingface.co/TheBloke/phi-2-GGUF) | [`d71ac90`](https://github.com/ggerganov/llama.cpp/tree/d71ac90985854b0905e1abba778e407e17f9f887) |
## License