using LLama.Abstractions;
using System.Collections.Generic;
using System.Runtime.InteropServices;
namespace LLama.Native
{
#if NET6_0_OR_GREATER
/// <summary>
/// A native library compiled with cublas/cuda.
/// </summary>
public class NativeLibraryWithCuda : INativeLibrary
{
    /// <summary>
    /// Sentinel passed as the major CUDA version when no specific version was requested.
    /// </summary>
    private const int UnspecifiedCudaVersion = -1;

    /// <summary>
    /// CUDA major versions tried, in this order, when no version was requested and the
    /// check is skipped. Currently only 11 and 12 are supported.
    /// </summary>
    private static readonly int[] FallbackCudaVersions = { 12, 11 };

    private readonly int _majorCudaVersion;
    private readonly NativeLibraryName _libraryName;
    // Never assigned in this class, so it always holds default(AvxLevel); see the TODO in Prepare.
    private readonly AvxLevel _avxLevel;
    private readonly bool _skipCheck;

    /// <summary>
    /// Metadata describing this library, built on every access from the library name,
    /// a constant <c>true</c> flag and the (currently default) AVX level.
    /// </summary>
    public NativeLibraryMetadata? Metadata
    {
        get
        {
            // NOTE(review): the literal `true` is presumably the "uses CUDA" flag — confirm against NativeLibraryMetadata.
            return new NativeLibraryMetadata(_libraryName, true, _avxLevel);
        }
    }

    /// <summary>
    /// Creates a description of a CUDA-enabled native library.
    /// </summary>
    /// <param name="majorCudaVersion">
    /// The major CUDA version whose library should be used, or <c>-1</c> when no specific version is requested.
    /// </param>
    /// <param name="libraryName">The library whose file name is used to build the candidate path.</param>
    /// <param name="skipCheck">
    /// When <c>true</c>, the operating-system check in <see cref="Prepare"/> is bypassed and, if
    /// <paramref name="majorCudaVersion"/> is <c>-1</c>, the CUDA 12 and CUDA 11 paths are both offered.
    /// </param>
    public NativeLibraryWithCuda(int majorCudaVersion, NativeLibraryName libraryName, bool skipCheck)
    {
        _majorCudaVersion = majorCudaVersion;
        _libraryName = libraryName;
        _skipCheck = skipCheck;
    }

    /// <summary>
    /// Lazily produces the relative paths of the CUDA library files that should be tried, in priority order.
    /// </summary>
    /// <param name="systemInfo">Information about the current system; only its OS platform is read here.</param>
    /// <param name="logCallback">Log callback forwarded to the path builder (currently unused there).</param>
    /// <returns>
    /// Zero or more relative library paths. Nothing is produced when the platform is neither Windows nor Linux
    /// and the check is not skipped, or when no CUDA version was specified and the check is not skipped.
    /// </returns>
    public IEnumerable<string> Prepare(SystemInfo systemInfo, NativeLogConfig.LLamaLogCallback? logCallback)
    {
        // TODO: Avx level is ignored now, needs to be implemented in the future.
        var isSupportedPlatform = systemInfo.OSPlatform == OSPlatform.Windows
                                  || systemInfo.OSPlatform == OSPlatform.Linux;
        if (!isSupportedPlatform && !_skipCheck)
        {
            yield break;
        }

        if (_majorCudaVersion != UnspecifiedCudaVersion)
        {
            // An explicit version was requested: offer exactly that one.
            var requestedPath = GetCudaPath(systemInfo, _majorCudaVersion, logCallback);
            if (requestedPath is not null)
            {
                yield return requestedPath;
            }
        }
        else if (_skipCheck)
        {
            // No version is known and the caller opted out of checking: offer every supported version, newest first.
            foreach (var candidateVersion in FallbackCudaVersions)
            {
                var candidatePath = GetCudaPath(systemInfo, candidateVersion, logCallback);
                if (candidatePath is not null)
                {
                    yield return candidatePath;
                }
            }
        }
    }

    /// <summary>
    /// Builds the relative path <c>runtimes/{os}/native/cuda{version}/{prefix}{name}{extension}</c>
    /// for the given CUDA major version on the given platform.
    /// </summary>
    /// <param name="systemInfo">Supplies the OS platform used to pick the folder, prefix and extension.</param>
    /// <param name="cudaVersion">The CUDA major version embedded in the path.</param>
    /// <param name="logCallback">Not used by the current implementation.</param>
    /// <returns>The relative path; the current implementation never returns <c>null</c>.</returns>
    private string? GetCudaPath(SystemInfo systemInfo, int cudaVersion, NativeLogConfig.LLamaLogCallback? logCallback)
    {
        NativeLibraryUtils.GetPlatformPathParts(systemInfo.OSPlatform, out var os, out var fileExtension, out var libPrefix);
        return $"runtimes/{os}/native/cuda{cudaVersion}/{libPrefix}{_libraryName.GetLibraryName()}{fileExtension}";
    }
}
#endif
}