#===============================================================================
# Copyright 2021 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#===============================================================================

import threading
from contextlib import contextmanager

from sklearn import get_config as skl_get_config
from sklearn import set_config as skl_set_config

# Initial values of the sklearnex-specific settings; each thread starts from
# its own copy of this dict (see _get_sklearnex_threadlocal_config).
_default_global_config = {
    "target_offload": "auto",
    "allow_fallback_to_host": False,
}

# Per-thread storage for the sklearnex-specific settings.
_threadlocal = threading.local()


def _get_sklearnex_threadlocal_config():
    """Return the calling thread's sklearnex config dict.

    On first access in a thread, the dict is initialized as a copy of
    ``_default_global_config``. The returned dict is the live per-thread
    object: mutating it changes this thread's configuration.
    """
    if not hasattr(_threadlocal, "global_config"):
        _threadlocal.global_config = _default_global_config.copy()
    return _threadlocal.global_config


def get_config():
    """Retrieve current values for configuration set by :func:`set_config`

    The result merges scikit-learn's configuration (``sklearn.get_config``)
    with the calling thread's sklearnex-specific settings
    (``target_offload`` and ``allow_fallback_to_host``).

    Returns
    -------
    config : dict
        Keys are parameter names that can be passed to :func:`set_config`.
        A new dict is returned each call; mutating it does not change the
        configuration.

    See Also
    --------
    config_context : Context manager for global configuration.
    set_config : Set global configuration.
    """
    sklearn = skl_get_config()
    sklearnex = _get_sklearnex_threadlocal_config().copy()
    return {**sklearn, **sklearnex}


def set_config(target_offload=None, allow_fallback_to_host=None, **sklearn_configs):
    """Set global configuration

    Parameters
    ----------
    target_offload : string or dpctl.SyclQueue, default=None
        The device primarily used to perform computations.
        If string, expected to be "auto" (the execution context
        is deduced from input data location),
        or SYCL* filter selector string. Global default: "auto".
        ``None`` leaves the current value unchanged.

    allow_fallback_to_host : bool, default=None
        If True, allows computation to fall back to the host device
        in case a particular estimator does not support the selected one.
        Global default: False.
        ``None`` leaves the current value unchanged.

    **sklearn_configs
        Additional keyword arguments, forwarded unchanged to
        ``sklearn.set_config``.

    See Also
    --------
    config_context : Context manager for global configuration.
    get_config : Retrieve current values of the global configuration.
    """
    skl_set_config(**sklearn_configs)

    local_config = _get_sklearnex_threadlocal_config()

    if target_offload is not None:
        local_config["target_offload"] = target_offload
    if allow_fallback_to_host is not None:
        local_config["allow_fallback_to_host"] = allow_fallback_to_host


@contextmanager
def config_context(**new_config):
    """Context manager that temporarily changes scikit-learn and sklearnex configuration

    Parameters
    ----------
    target_offload : string or dpctl.SyclQueue, default=None
        The device primarily used to perform computations.
        If string, expected to be "auto" (the execution context
        is deduced from input data location),
        or SYCL* filter selector string. Global default: "auto".

    allow_fallback_to_host : bool, default=None
        If True, allows computation to fall back to the host device
        in case a particular estimator does not support the selected one.
        Global default: False.

    **new_config
        Any other keyword accepted by :func:`set_config`; scikit-learn
        settings are forwarded to ``sklearn.set_config``.

    Notes
    -----
    All settings, not just those presently modified, will be returned to
    their previous values when the context manager is exited. This also
    happens if applying ``new_config`` fails partway or if the body of the
    ``with`` block raises.

    See Also
    --------
    set_config : Set global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.
    """
    old_config = get_config()
    try:
        # Apply inside the ``try`` so that a failure after some settings
        # were already stored still triggers the restore below.
        set_config(**new_config)
        yield
    finally:
        set_config(**old_config)