diff --git a/LLama/LLamaQuantizer.cs b/LLama/LLamaQuantizer.cs
index 232d9f28..2114d0be 100644
--- a/LLama/LLamaQuantizer.cs
+++ b/LLama/LLamaQuantizer.cs
@@ -20,14 +20,22 @@ namespace LLama
/// Thread to be used during the quantization. By default it's the physical core number.
/// Whether the quantization is successful.
///
- public static bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1)
+ public static unsafe bool Quantize(string srcFileName, string dstFilename, LLamaFtype ftype, int nthread = -1, bool allowRequantize = true,
+ bool quantizeOutputTensor = false)
{
if (!ValidateFtype(ftype))
{
throw new ArgumentException($"The type {Enum.GetName(typeof(LLamaFtype), ftype)} is not a valid type " +
$"to perform quantization.");
}
- return NativeApi.llama_model_quantize(srcFileName, dstFilename, ftype, nthread) == 0;
+
+ var quantizeParams = NativeApi.llama_model_quantize_default_params();
+ quantizeParams.ftype = ftype;
+ quantizeParams.nthread = nthread;
+ quantizeParams.allow_requantize = allowRequantize;
+ quantizeParams.quantize_output_tensor = quantizeOutputTensor;
+ LLamaModelQuantizeParams* p = &quantizeParams;
+ return NativeApi.llama_model_quantize(srcFileName, dstFilename, p) == 0;
}
///
@@ -39,9 +47,10 @@ namespace LLama
/// Thread to be used during the quantization. By default it's the physical core number.
/// Whether the quantization is successful.
///
- public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = -1)
+ public static bool Quantize(string srcFileName, string dstFilename, string ftype, int nthread = -1, bool allowRequantize = true,
+ bool quantizeOutputTensor = false)
{
- return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread);
+ return Quantize(srcFileName, dstFilename, StringToFtype(ftype), nthread, allowRequantize, quantizeOutputTensor);
}
private static bool ValidateFtype(string ftype)
diff --git a/LLama/Native/LLamaContextParams.cs b/LLama/Native/LLamaContextParams.cs
index aa448ab9..3d0e2cab 100644
--- a/LLama/Native/LLamaContextParams.cs
+++ b/LLama/Native/LLamaContextParams.cs
@@ -9,19 +9,44 @@ namespace LLama.Native
[StructLayout(LayoutKind.Sequential)]
public struct LLamaContextParams
{
+ ///
+ /// RNG seed, -1 for random
+ ///
+ public int seed;
///
/// text context
///
public int n_ctx;
///
+ /// prompt processing batch size
+ ///
+ public int n_batch;
+ ///
/// number of layers to store in VRAM
///
public int n_gpu_layers;
///
- /// RNG seed, -1 for random
+ /// the GPU that is used for scratch and small tensors
///
- public int seed;
+ public int main_gpu;
+ ///
+ /// how to split layers across multiple GPUs
+ ///
+ public TensorSplits tensor_split;
+ ///
+ /// called with a progress value between 0 and 1, pass NULL to disable
+ ///
+ public IntPtr progress_callback;
+ ///
+ /// context pointer passed to the progress callback
+ ///
+ public IntPtr progress_callback_user_data;
+ ///
+ /// if true, reduce VRAM usage at the cost of performance
+ ///
+ [MarshalAs(UnmanagedType.I1)]
+ public bool low_vram;
///
/// use fp16 for KV cache
///
@@ -52,14 +77,10 @@ namespace LLama.Native
///
[MarshalAs(UnmanagedType.I1)]
public bool embedding;
+ }
- ///
- /// called with a progress value between 0 and 1, pass NULL to disable
- ///
- public IntPtr progress_callback;
- ///
- /// context pointer passed to the progress callback
- ///
- public IntPtr progress_callback_user_data;
+ public struct TensorSplits
+ {
+ public float Item1;
}
}
diff --git a/LLama/Native/LLamaFtype.cs b/LLama/Native/LLamaFtype.cs
index 8ecc224f..41159ee2 100644
--- a/LLama/Native/LLamaFtype.cs
+++ b/LLama/Native/LLamaFtype.cs
@@ -16,5 +16,14 @@ namespace LLama.Native
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
+ LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
}
}
diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs
new file mode 100644
index 00000000..ebbfb1de
--- /dev/null
+++ b/LLama/Native/LLamaModelQuantizeParams.cs
@@ -0,0 +1,29 @@
+using System;
+using System.Collections.Generic;
+using System.Runtime.InteropServices;
+using System.Text;
+
+namespace LLama.Native
+{
+ public struct LLamaModelQuantizeParams
+ {
+ ///
+ /// number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency()
+ ///
+ public int nthread;
+ ///
+ /// quantize to this llama_ftype
+ ///
+ public LLamaFtype ftype;
+ ///
+ /// allow quantizing non-f32/f16 tensors
+ ///
+ [MarshalAs(UnmanagedType.I1)]
+ public bool allow_requantize;
+ ///
+ /// quantize output.weight
+ ///
+ [MarshalAs(UnmanagedType.I1)]
+ public bool quantize_output_tensor;
+ }
+}
diff --git a/LLama/Native/NativeApi.Quantize.cs b/LLama/Native/NativeApi.Quantize.cs
index 978026fc..c1eed4e4 100644
--- a/LLama/Native/NativeApi.Quantize.cs
+++ b/LLama/Native/NativeApi.Quantize.cs
@@ -17,6 +17,6 @@ namespace LLama.Native
/// not great API - very likely to change
/// Returns 0 on success
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaFtype ftype, int nthread);
+ public unsafe static extern int llama_model_quantize(string fname_inp, string fname_out, LLamaModelQuantizeParams* param);
}
}
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index e5dd9dfb..4540a123 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -10,6 +10,7 @@ namespace LLama.Native
using llama_token = Int32;
public unsafe partial class NativeApi
{
+ public static readonly int LLAMA_MAX_DEVICES = 1;
static NativeApi()
{
try
@@ -34,6 +35,9 @@ namespace LLama.Native
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaContextParams llama_context_default_params();
+ [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
+ public static extern LLamaModelQuantizeParams llama_model_quantize_default_params();
+
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_mmap_supported();