using LLama.Abstractions;
using System.Collections.Generic;
using System.Runtime.InteropServices;

namespace LLama.Native
{
#if NET6_0_OR_GREATER
    /// <summary>
    /// A native library compiled with cublas/cuda.
    /// </summary>
    public class NativeLibraryWithCuda : INativeLibrary
    {
        private readonly int _majorCudaVersion;
        private readonly NativeLibraryName _libraryName;
        // NOTE(review): this field is never assigned, so Metadata always reports default(AvxLevel).
        // AVX selection for CUDA builds is not implemented yet (see the TODO in Prepare).
        private readonly AvxLevel _avxLevel;
        private readonly bool _skipCheck;

        /// <inheritdoc/>
        public NativeLibraryMetadata? Metadata => new NativeLibraryMetadata(_libraryName, true, _avxLevel);

        /// <summary>
        /// Create a descriptor for a CUDA build of a native library.
        /// </summary>
        /// <param name="majorCudaVersion">
        /// Major CUDA version whose build should be loaded, or -1 when no specific version is given.
        /// With -1, candidate paths are only produced when <paramref name="skipCheck"/> is true.
        /// </param>
        /// <param name="libraryName">Which native library to resolve a path for.</param>
        /// <param name="skipCheck">
        /// When true, candidate paths are produced on every OS platform (not only Windows and Linux),
        /// and a <paramref name="majorCudaVersion"/> of -1 probes both the CUDA 12 and CUDA 11 builds.
        /// </param>
        public NativeLibraryWithCuda(int majorCudaVersion, NativeLibraryName libraryName, bool skipCheck)
        {
            _majorCudaVersion = majorCudaVersion;
            _libraryName = libraryName;
            _skipCheck = skipCheck;
        }

        /// <inheritdoc/>
        public IEnumerable<string> Prepare(SystemInfo systemInfo, NativeLogConfig.LLamaLogCallback? logCallback)
        {
            // TODO: Avx level is ignored now, needs to be implemented in the future.

            // CUDA builds are only probed on Windows and Linux, unless the caller asked to skip checks.
            var isCudaPlatform = systemInfo.OSPlatform == OSPlatform.Windows
                                 || systemInfo.OSPlatform == OSPlatform.Linux;
            if (!isCudaPlatform && !_skipCheck)
            {
                yield break;
            }

            foreach (var cudaVersion in GetCandidateCudaVersions())
            {
                var cudaLibraryPath = GetCudaPath(systemInfo, cudaVersion, logCallback);
                if (cudaLibraryPath is not null)
                {
                    yield return cudaLibraryPath;
                }
            }
        }

        /// <summary>
        /// Major CUDA versions to probe, in priority order: the explicitly configured version if there
        /// is one; otherwise (version -1) CUDA 12 then CUDA 11, but only when checks are skipped.
        /// </summary>
        private IEnumerable<int> GetCandidateCudaVersions()
        {
            if (_majorCudaVersion != -1)
            {
                yield return _majorCudaVersion;
            }
            else if (_skipCheck)
            {
                // Currently only 11 and 12 are supported.
                yield return 12;
                yield return 11;
            }
        }

        /// <summary>
        /// Build the relative path of the CUDA build of this library for a given major CUDA version,
        /// of the form <c>runtimes/{os}/native/cuda{version}/{prefix}{name}{extension}</c>.
        /// </summary>
        /// <remarks>
        /// Currently never returns null; the nullable return lets callers skip unavailable versions.
        /// <paramref name="logCallback"/> is not used yet.
        /// </remarks>
        private string? GetCudaPath(SystemInfo systemInfo, int cudaVersion, NativeLogConfig.LLamaLogCallback? logCallback)
        {
            NativeLibraryUtils.GetPlatformPathParts(systemInfo.OSPlatform, out var os, out var fileExtension, out var libPrefix);
            var relativePath = $"runtimes/{os}/native/cuda{cudaVersion}/{libPrefix}{_libraryName.GetLibraryName()}{fileExtension}";
            return relativePath;
        }
    }
#endif
}