diff --git a/doubleml/plm/tests/test_lplr_exceptions.py b/doubleml/plm/tests/test_lplr_exceptions.py index 404770fa..c58d7aa0 100644 --- a/doubleml/plm/tests/test_lplr_exceptions.py +++ b/doubleml/plm/tests/test_lplr_exceptions.py @@ -241,7 +241,7 @@ def test_lplr_exception_learner(): def test_lplr_exception_and_warning_learner(): # invalid ml_M (must be a classifier with predict_proba) with pytest.raises(TypeError): - _ = DoubleMLLPLR(dml_data, _DummyNoClassifier(), ml_t, ml_m) + _ = DoubleMLLPLR(dml_data, Lasso(), ml_t, ml_m) msg = "Invalid learner provided for ml_M: " + r"Lasso\(\) has no method .predict_proba\(\)." with pytest.raises(TypeError, match=msg): _ = DoubleMLLPLR(dml_data, Lasso(), ml_t, ml_m) diff --git a/doubleml/tests/test_optuna_multi_wrappers.py b/doubleml/tests/test_optuna_multi_wrappers.py index d484b4ef..332fd6cb 100644 --- a/doubleml/tests/test_optuna_multi_wrappers.py +++ b/doubleml/tests/test_optuna_multi_wrappers.py @@ -23,7 +23,7 @@ def _build_apos_object(): dml_data = dml.DoubleMLData(df, "y", "d") ml_g = LinearRegression() - ml_m = LogisticRegression(max_iter=200, multi_class="auto") + ml_m = LogisticRegression(max_iter=200) return dml.DoubleMLAPOS( dml_data, @@ -38,7 +38,7 @@ def _build_apos_object(): def _build_qte_object(): np.random.seed(3141) dml_data = make_irm_data(n_obs=80, dim_x=5) - ml = LogisticRegression(max_iter=200, multi_class="auto") + ml = LogisticRegression(max_iter=200) return dml.DoubleMLQTE( dml_data, @@ -57,7 +57,7 @@ def _build_did_multi_object(): dml_panel = DoubleMLPanelData(df, y_col="y", d_cols="d", t_col="t", id_col="id", x_cols=x_cols) ml_g = LinearRegression() - ml_m = LogisticRegression(max_iter=200, multi_class="auto") + ml_m = LogisticRegression(max_iter=200) return dml.did.DoubleMLDIDMulti( obj_dml_data=dml_panel, diff --git a/pyproject.toml b/pyproject.toml index a3fe2414..d250d494 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dev = [ "pytest>=8.3.0", "pytest-cov>=6.0.0", "xgboost>=2.1.0", - "lightgbm>=4.5.0", + "lightgbm>=4.6.0", "black>=25.1.0", "ruff>=0.11.1", "pre-commit>=4.2.0",