import itertools
import re
import time
import warnings

import joblib
import numpy as np
import pytest
from numpy.testing import assert_array_equal

from sklearn import config_context, get_config
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.exceptions import ConvergenceWarning
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.utils.fixes import _IS_WASM
from sklearn.utils.parallel import Parallel, delayed


def get_working_memory():
    return get_config()["working_memory"]


@pytest.mark.parametrize("n_jobs", [1, 2])
@pytest.mark.parametrize("backend", ["loky", "threading", "multiprocessing"])
def test_configuration_passes_through_to_joblib(n_jobs, backend):
    # Check that the global configuration is passed to the joblib jobs
    with config_context(working_memory=123):
        results = Parallel(n_jobs=n_jobs, backend=backend)(
            delayed(get_working_memory)() for _ in range(2)
        )

    assert_array_equal(results, [123] * 2)


def test_parallel_delayed_warnings():
    """Informative warnings should be raised when mixing the sklearn and joblib APIs."""
    # We should issue a warning when one wants to use
    # sklearn.utils.parallel.Parallel with joblib.delayed. The config will not
    # be propagated to the workers.
    warn_msg = "`sklearn.utils.parallel.Parallel` needs to be used in conjunction"
    with pytest.warns(UserWarning, match=warn_msg) as records:
        Parallel()(joblib.delayed(time.sleep)(0) for _ in range(10))
    assert len(records) == 10

    # We should issue a warning if one wants to use
    # sklearn.utils.parallel.delayed with joblib.Parallel
    warn_msg = (
        "`sklearn.utils.parallel.delayed` should be used with "
        "`sklearn.utils.parallel.Parallel` to make it possible to propagate"
    )
    with pytest.warns(UserWarning, match=warn_msg) as records:
        joblib.Parallel()(delayed(time.sleep)(0) for _ in range(10))
    assert len(records) == 10
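

# Illustrative sketch (not a test): the supported pairing imports both
# ``Parallel`` and ``delayed`` from ``sklearn.utils.parallel``, so that the
# configuration active at dispatch time is propagated to every worker. The
# helper name below is hypothetical, for illustration only.
def _example_supported_pairing():
    with config_context(working_memory=123):
        # Each task observes working_memory=123, regardless of the backend.
        return Parallel(n_jobs=2, backend="threading")(
            delayed(get_working_memory)() for _ in range(2)
        )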


@pytest.mark.parametrize("n_jobs", [1, 2])
def test_dispatch_config_parallel(n_jobs):
    """Check that we properly dispatch the configuration in parallel processing.

    Non-regression test for:
    https://github.com/scikit-learn/scikit-learn/issues/25239
    """
    pd = pytest.importorskip("pandas")
    iris = load_iris(as_frame=True)

    class TransformerRequiredDataFrame(StandardScaler):
        def fit(self, X, y=None):
            assert isinstance(X, pd.DataFrame), "X should be a DataFrame"
            return super().fit(X, y)

        def transform(self, X, y=None):
            assert isinstance(X, pd.DataFrame), "X should be a DataFrame"
            # StandardScaler.transform does not take `y`; forwarding it would
            # bind to the `copy` parameter.
            return super().transform(X)

    dropper = make_column_transformer(
        ("drop", [0]),
        remainder="passthrough",
        n_jobs=n_jobs,
    )
    param_grid = {"randomforestclassifier__max_depth": [1, 2, 3]}
    search_cv = GridSearchCV(
        make_pipeline(
            dropper,
            TransformerRequiredDataFrame(),
            RandomForestClassifier(n_estimators=5, n_jobs=n_jobs),
        ),
        param_grid,
        cv=5,
        n_jobs=n_jobs,
        error_score="raise",  # this search should not fail
    )

    # make sure that `fit` would fail in case we don't request a dataframe
    with pytest.raises(AssertionError, match="X should be a DataFrame"):
        search_cv.fit(iris.data, iris.target)

    with config_context(transform_output="pandas"):
        # we expect each intermediate step to output a DataFrame
        search_cv.fit(iris.data, iris.target)

    assert not np.isnan(search_cv.cv_results_["mean_test_score"]).any()
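

# Illustrative sketch (not a test): what ``transform_output="pandas"`` means
# for a single transformer, independently of any parallelism. A minimal sketch
# of the set_output mechanism the test above relies on; the helper name is
# hypothetical.
def _example_pandas_transform_output():
    import pandas as pd

    df = pd.DataFrame({"a": [0.0, 1.0], "b": [2.0, 3.0]})
    with config_context(transform_output="pandas"):
        # The scaler wraps its ndarray output in a DataFrame that reuses the
        # input column names.
        out = StandardScaler().fit_transform(df)
    assert isinstance(out, pd.DataFrame)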


def raise_warning():
    warnings.warn("Convergence warning", ConvergenceWarning)


def _yield_n_jobs_backend_combinations():
    n_jobs_values = [1, 2]
    backend_values = ["loky", "threading", "multiprocessing"]
    for n_jobs, backend in itertools.product(n_jobs_values, backend_values):
        if n_jobs == 2 and backend == "loky":
            # XXX Mark thread-unsafe to avoid:
            # RuntimeError: The executor underlying Parallel has been shutdown.
            # See https://github.com/joblib/joblib/issues/1743 for more details.
            yield pytest.param(n_jobs, backend, marks=pytest.mark.thread_unsafe)
        else:
            yield n_jobs, backend


@pytest.mark.parametrize("n_jobs, backend", _yield_n_jobs_backend_combinations())
def test_filter_warning_propagates(n_jobs, backend):
    """Check that warning filters propagate to the jobs."""
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        with pytest.raises(ConvergenceWarning):
            Parallel(n_jobs=n_jobs, backend=backend)(
                delayed(raise_warning)() for _ in range(2)
            )


def get_warning_filters():
    # In free-threaded Python >= 3.14, warnings filters are managed through a
    # ContextVar and warnings.filters is not modified inside a
    # warnings.catch_warnings context; warnings._get_filters() must be used
    # instead. For more details, see
    # https://docs.python.org/3.14/whatsnew/3.14.html#concurrent-safe-warnings-control
    # An illustrative sketch of this lookup is given at the end of this module.
    filters_func = getattr(warnings, "_get_filters", None)
    return filters_func() if filters_func is not None else warnings.filters


def test_check_warnings_threading():
    """Check that warnings filters are set correctly in the threading backend."""
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        main_warning_filters = get_warning_filters()
        assert ("error", None, ConvergenceWarning, None, 0) in main_warning_filters

        all_worker_warning_filters = Parallel(n_jobs=2, backend="threading")(
            delayed(get_warning_filters)() for _ in range(2)
        )

        def normalize_main_module(filters):
            # On free-threaded Python 3.14 there is a small discrepancy: the
            # main warning filters have an entry with module "__main__",
            # whereas in the workers the module is a compiled regex.
            return [
                (
                    action,
                    message,
                    type_,
                    (
                        module
                        if "__main__" not in str(module)
                        or not isinstance(module, re.Pattern)
                        else module.pattern
                    ),
                    lineno,
                )
                for action, message, type_, module, lineno in filters
            ]

        for worker_warning_filter in all_worker_warning_filters:
            assert normalize_main_module(
                worker_warning_filter
            ) == normalize_main_module(main_warning_filters)


@pytest.mark.xfail(_IS_WASM, reason="Pyodide always uses the sequential backend")
def test_filter_warning_propagates_no_side_effect_with_loky_backend():
    with warnings.catch_warnings():
        warnings.simplefilter("error", category=ConvergenceWarning)

        Parallel(n_jobs=2, backend="loky")(delayed(time.sleep)(0) for _ in range(10))

    # Since loky workers are reused, make sure that inside the loky workers,
    # warnings filters have been reset to their original value. Using joblib
    # directly should not turn ConvergenceWarning into an error.
    joblib.Parallel(n_jobs=2, backend="loky")(
        joblib.delayed(warnings.warn)("Convergence warning", ConvergenceWarning)
        for _ in range(10)
    )
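

# Illustrative sketch (not a test), referenced from get_warning_filters above:
# on any supported build, a filter installed with simplefilter inside a
# catch_warnings block is visible through get_warning_filters(), whether the
# lookup goes through warnings._get_filters() (free-threaded 3.14+) or
# warnings.filters. The helper name is hypothetical.
def _example_filters_lookup():
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=UserWarning)
        assert any(
            action == "ignore" and category is UserWarning
            for action, _, category, _, _ in get_warning_filters()
        )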