""" CUDA Runtime wrapper. This provides a very minimal set of bindings, since the Runtime API is not really used in Numba except for querying the Runtime version. """ import ctypes import functools import sys from numba.core import config from numba.cuda.cudadrv.driver import ERROR_MAP, make_logger from numba.cuda.cudadrv.error import CudaSupportError, CudaRuntimeError from numba.cuda.cudadrv.libs import open_cudalib from numba.cuda.cudadrv.rtapi import API_PROTOTYPES from numba.cuda.cudadrv import enums class CudaRuntimeAPIError(CudaRuntimeError): """ Raised when there is an error accessing a C API from the CUDA Runtime. """ def __init__(self, code, msg): self.code = code self.msg = msg super().__init__(code, msg) def __str__(self): return "[%s] %s" % (self.code, self.msg) class Runtime: """ Runtime object that lazily binds runtime API functions. """ def __init__(self): self.is_initialized = False def _initialize(self): # lazily initialize logger global _logger _logger = make_logger() if config.DISABLE_CUDA: msg = ("CUDA is disabled due to setting NUMBA_DISABLE_CUDA=1 " "in the environment, or because CUDA is unsupported on " "32-bit systems.") raise CudaSupportError(msg) self.lib = open_cudalib('cudart') self.is_initialized = True def __getattr__(self, fname): # First request of a runtime API function try: proto = API_PROTOTYPES[fname] except KeyError: raise AttributeError(fname) restype = proto[0] argtypes = proto[1:] if not self.is_initialized: self._initialize() # Find function in runtime library libfn = self._find_api(fname) libfn.restype = restype libfn.argtypes = argtypes safe_call = self._wrap_api_call(fname, libfn) setattr(self, fname, safe_call) return safe_call def _wrap_api_call(self, fname, libfn): @functools.wraps(libfn) def safe_cuda_api_call(*args): _logger.debug('call runtime api: %s', libfn.__name__) retcode = libfn(*args) self._check_error(fname, retcode) return safe_cuda_api_call def _check_error(self, fname, retcode): if retcode != enums.CUDA_SUCCESS: errname = ERROR_MAP.get(retcode, "cudaErrorUnknown") msg = "Call to %s results in %s" % (fname, errname) _logger.error(msg) raise CudaRuntimeAPIError(retcode, msg) def _find_api(self, fname): try: return getattr(self.lib, fname) except AttributeError: pass # Not found. # Delay missing function error to use def absent_function(*args, **kws): msg = "runtime missing function: %s." raise CudaRuntimeError(msg % fname) setattr(self, fname, absent_function) return absent_function def get_version(self): """ Returns the CUDA Runtime version as a tuple (major, minor). """ rtver = ctypes.c_int() self.cudaRuntimeGetVersion(ctypes.byref(rtver)) # The version is encoded as (1000 * major) + (10 * minor) major = rtver.value // 1000 minor = (rtver.value - (major * 1000)) // 10 return (major, minor) def is_supported_version(self): """ Returns True if the CUDA Runtime is a supported version. """ return self.get_version() in self.supported_versions @property def supported_versions(self): """A tuple of all supported CUDA toolkit versions. Versions are given in the form ``(major_version, minor_version)``.""" if sys.platform not in ('linux', 'win32') or config.MACHINE_BITS != 64: # Only 64-bit Linux and Windows are supported return () return ((9, 2), (10, 0), (10, 1), (10, 2), (11, 0), (11, 1), (11, 2)) runtime = Runtime()