#!/usr/bin/env python #=============================================================================== # Copyright 2021 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. #=============================================================================== def get_patch_str(name=None, verbose=True): return f"""try: from sklearnex import patch_sklearn patch_sklearn(name={str(name)}, verbose={str(verbose)}) del patch_sklearn except ImportError: pass""" def get_patch_str_re(): return r"""\ntry: from sklearnex import patch_sklearn patch_sklearn\(name=.*, verbose=.*\) del patch_sklearn except ImportError: pass\n""" def patch_sklearn_global(name=None, verbose=True): import os import re try: import sklearn except ImportError: raise ImportError("Scikit-learn could not be imported. Nothing to patch\n") init_file_path = sklearn.__file__ distributor_file_path = os.path.join(os.path.dirname(init_file_path), "_distributor_init.py") with open(distributor_file_path, 'r', encoding='utf-8') as distributor_file: lines = distributor_file.read() if re.search(get_patch_str_re(), lines): lines = re.sub(get_patch_str_re(), '', lines) with open(distributor_file_path, 'w', encoding='utf-8') as distributor_file: distributor_file.write(lines + "\n" + get_patch_str(name, verbose) + "\n") print("Scikit-learn was successfully globally patched" " by Intel(R) Extension for Scikit-learn") return def unpatch_sklearn_global(): import os import re try: import sklearn except ImportError: raise ImportError("Scikit-learn could not be imported. Nothing to unpatch\n") init_file_path = sklearn.__file__ distributor_file_path = os.path.join(os.path.dirname(init_file_path), "_distributor_init.py") with open(distributor_file_path, 'r', encoding='utf-8') as distributor_file: lines = distributor_file.read() if not re.search(get_patch_str_re(), lines): print("Nothing to unpatch: Scikit-learn is not patched\n") return lines = re.sub(get_patch_str_re(), '', lines) with open(distributor_file_path, 'w', encoding='utf-8') as distributor_file: distributor_file.write(lines) print("Scikit-learn was successfully globally unpatched")