# Copyright (C) 2012 Anaconda, Inc
# SPDX-License-Identifier: BSD-3-Clause
"""Detect CUDA version."""

import ctypes
import functools
import itertools
import multiprocessing
import os
import platform
from contextlib import suppress

from .. import CondaVirtualPackage, hookimpl


def cuda_version():
    """
    Attempt to detect the version of CUDA present in the operating system.

    On Windows and Linux, the CUDA library is installed by the NVIDIA
    driver package, and is typically found in the standard library path,
    rather than with the CUDA SDK (which is optional for running CUDA apps).

    On macOS, the CUDA library is only installed with the CUDA SDK, and
    might not be in the library path.

    The ``CONDA_OVERRIDE_CUDA`` environment variable, when set, short-circuits
    detection: its stripped value is returned, or None if it is empty.

    Returns: version string (e.g., '9.2') or None if CUDA is not found.
    """
    if "CONDA_OVERRIDE_CUDA" in os.environ:
        return os.environ["CONDA_OVERRIDE_CUDA"].strip() or None

    # Do not inherit file descriptors and handles from the parent process.
    # The `fork` start method should be considered unsafe as it can lead to
    # crashes of the subprocess. The `spawn` start method is preferred.
    context = multiprocessing.get_context("spawn")
    queue = context.SimpleQueue()

    # Spawn a subprocess to detect the CUDA version.
    # The process is created and started *before* entering the `try` block:
    # if either step fails there is no running subprocess to clean up, and
    # calling `kill()` on an unbound or never-started process would raise a
    # secondary error that hides the real failure.
    detector = context.Process(
        target=_cuda_driver_version_detector_target,
        args=(queue,),
        name="CUDA driver version detector",
        daemon=True,
    )
    detector.start()
    try:
        # Give the driver a bounded amount of time to answer.
        detector.join(timeout=60.0)
    finally:
        # Always cleanup the subprocess (no-op if it has already exited)
        detector.kill()  # requires Python 3.7+

    # An empty queue means the subprocess timed out or died before reporting.
    if queue.empty():
        return None
    return queue.get()


@functools.lru_cache(maxsize=None)
def cached_cuda_version():
    """A cached version of the cuda detection system."""
    return cuda_version()


@hookimpl
def conda_virtual_packages():
    """Yield the ``cuda`` virtual package when a CUDA version is detected."""
    # Named `version` so the local does not shadow the `cuda_version` function.
    version = cached_cuda_version()
    if version is not None:
        yield CondaVirtualPackage("cuda", version, None)


def _cuda_library_candidates(system):
    """Return the libcuda names/paths to try on *system*, or None if unsupported."""
    if system == "Darwin":
        return [
            "libcuda.1.dylib",  # check library path first
            "libcuda.dylib",
            "/usr/local/cuda/lib/libcuda.1.dylib",
            "/usr/local/cuda/lib/libcuda.dylib",
        ]
    if system == "Linux":
        lib_filenames = [
            "libcuda.so",  # check library path first
            "/usr/lib64/nvidia/libcuda.so",  # RHEL/Centos/Fedora
            "/usr/lib/x86_64-linux-gnu/libcuda.so",  # Ubuntu
            "/usr/lib/wsl/lib/libcuda.so",  # WSL
        ]
        # Also add libraries with version suffix `.1` (tried before the
        # unsuffixed name of the same location)
        return list(
            itertools.chain.from_iterable((f"{lib}.1", lib) for lib in lib_filenames)
        )
    if system == "Windows":
        bits = platform.architecture()[0].replace("bit", "")  # e.g. "64" or "32"
        return [f"nvcuda{bits}.dll", "nvcuda.dll"]
    return None  # CUDA not available for other operating systems


def _format_cuda_version(value):
    """Convert the driver's integer version into a 'major.minor' string.

    The value is decoded as ``1000 * major + 10 * minor``, so e.g. 9020
    becomes '9.2'.
    """
    return f"{value // 1000}.{(value % 1000) // 10}"


def _detect_cuda_driver_version():
    """Load the CUDA driver library and return its version string, or None."""
    # Platform-specific libcuda location
    system = platform.system()
    lib_filenames = _cuda_library_candidates(system)
    if lib_filenames is None:
        return None

    # Open library
    dll = ctypes.windll if system == "Windows" else ctypes.cdll
    for lib_filename in lib_filenames:
        # Best effort: any failure to load simply moves on to the next candidate.
        with suppress(Exception):
            libcuda = dll.LoadLibrary(lib_filename)
            break
    else:
        # None of the candidates could be loaded.
        return None

    # An empty `CUDA_VISIBLE_DEVICES` can cause `cuInit()` to return `CUDA_ERROR_NO_DEVICE`.
    # An invalid `CUDA_VISIBLE_DEVICES` can cause `cuInit()` to return `CUDA_ERROR_INVALID_DEVICE`.
    # Unset this environment variable to avoid these errors.
    os.environ.pop("CUDA_VISIBLE_DEVICES", None)

    # Get CUDA version
    try:
        # A non-zero return code from either driver call means failure.
        if libcuda.cuInit(ctypes.c_uint(0)) != 0:
            return None

        version_int = ctypes.c_int(0)
        if libcuda.cuDriverGetVersion(ctypes.byref(version_int)) != 0:
            return None

        # Convert version integer to version string
        return _format_cuda_version(version_int.value)
    except Exception:
        # Detection is best effort; any driver/ctypes error means "no CUDA".
        return None


def _cuda_driver_version_detector_target(queue):
    """
    Attempt to detect the version of CUDA present in the operating system in a
    subprocess.

    On Windows and Linux, the CUDA library is installed by the NVIDIA
    driver package, and is typically found in the standard library path,
    rather than with the CUDA SDK (which is optional for running CUDA apps).

    On macOS, the CUDA library is only installed with the CUDA SDK, and
    might not be in the library path.

    Returns: version string (e.g., '9.2') or None if CUDA is not found.
             The result is put in the queue rather than a return value.
    """
    queue.put(_detect_cuda_driver_version())