diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs index 232d9f28..2114d0be 100644 --- a/LLama/LLamaQuantizer.cs +++ b/LLama/LLamaQuantizer.cs @@ -20,14 +20,22 @@ namespace LLama /// Thread to be used during the quantization. By default it's the physical core number. /// Whether the quantization is successful. /// - public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1) + public static unsafe bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true, + bool quantizeOutputTensor = false) { if (!ValidateFtype(ftype)) { throw new ArgumentException($"The type {Enum.GetName(typeof(LLamaFtype), ftype)} is not a valid type " + $"to perform quantization."); } - return NativeApi.llama_model_quantize(srcFileName, dstFilename, ftype, nthread) == 0; + + var quantizeParams = NativeApi.llama_model_quantize_default_params(); + quantizeParams.ftype = ftype; + quantizeParams.nthread = nthread; + quantizeParams.allow_requantize = allowRequantize; + quantizeParams.quantize_output_tensor = quantizeOutputTensor; + LLamaModelQuantizeParams* p = &quantizeParams; + return NativeApi.llama_model_quantize(srcFileName, dstFilename, p) == 0; } /// @@ -39,9 +47,10 @@ namespace LLama /// Thread to be used during the quantization. By default it's the physical core number. /// Whether the quantization is successful. /// - public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = -1) + public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = -1, bool allowRequantize = true, + bool quantizeOutputTensor = false) { - return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread); + return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread, allowRequantize, quantizeOutputTensor); } private static bool ValidateFtype(string ftype) diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs index aa448ab9..3d0e2cab 100644 --- a/LLama/Native/LLamaContextParams.cs +++ b/LLama/Native/LLamaContextParams.cs @@ -9,19 +9,44 @@ namespace LLama.Native [StructLayout(LayoutKind.Sequential)] public struct LLamaContextParams { + /// + /// RNG seed, -1 for random + /// + public int seed; /// /// text context /// public int n_ctx; /// + /// prompt processing batch size + /// + public int n_batch; + /// /// number of layers to store in VRAM /// public int n_gpu_layers; /// - /// RNG seed, -1 for random + /// the GPU that is used for scratch and small tensors /// - public int seed; + public int main_gpu; + /// + /// how to split layers across multiple GPUs + /// + public TensorSplits tensor_split; + /// + /// called with a progress value between 0 and 1, pass NULL to disable + /// + public IntPtr progress_callback; + /// + /// context pointer passed to the progress callback + /// + public IntPtr progress_callback_user_data; + /// + /// if true, reduce VRAM usage at the cost of performance + /// + [MarshalAs(UnmanagedType.I1)] + public bool low_vram; /// /// use fp16 for KV cache /// @@ -52,14 +77,10 @@ namespace LLama.Native /// [MarshalAs(UnmanagedType.I1)] public bool embedding; + } - /// - /// called with a progress value between 0 and 1, pass NULL to disable - /// - public IntPtr progress_callback; - /// - /// context pointer passed to the progress callback - /// - public IntPtr progress_callback_user_data; + public struct TensorSplits + { + public float Item1; } } diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs index 8ecc224f..41159ee2 100644 --- a/LLama/Native/LLamaFtype.cs +++ b/LLama/Native/LLamaFtype.cs @@ -16,5 +16,14 @@ namespace LLama.Native LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors } } diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs new file mode 100644 index 00000000..ebbfb1de --- /dev/null +++ b/LLama/Native/LLamaModelQuantizeParams.cs @@ -0,0 +1,29 @@ +using System; +using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; + +namespace LLama.Native +{ + public struct LLamaModelQuantizeParams + { + /// + /// number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() + /// + public int nthread; + /// + /// quantize to this llama_ftype + /// + public LLamaFtype ftype; + /// + /// allow quantizing non-f32/f16 tensors + /// + [MarshalAs(UnmanagedType.I1)] + public bool allow_requantize; + /// + /// quantize output.weight + /// + [MarshalAs(UnmanagedType.I1)] + public bool quantize_output_tensor; + } +} diff --git a/LLama/Native/NativeApi.Quantize.cs b/LLama/Native/NativeApi.Quantize.cs index 978026fc..c1eed4e4 100644 --- a/LLama/Native/NativeApi.Quantize.cs +++ b/LLama/Native/NativeApi.Quantize.cs @@ -17,6 +17,6 @@ namespace LLama.Native /// not great API - very likely to change /// Returns 0 on success [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaFtype ftype, int nthread); + public unsafe static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* param); } } diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index e5dd9dfb..4540a123 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -10,6 +10,7 @@ namespace LLama.Native using llama_token = Int32; public unsafe partial class NativeApi { + public static readonly int LLAMA_MAX_DEVICES = 1; static NativeApi() { try @@ -34,6 +35,9 @@ namespace LLama.Native [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern LLamaContextParams llama_context_default_params(); + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] + public static extern LLamaModelQuantizeParams llama_model_quantize_default_params(); + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern bool llama_mmap_supported();