Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doubleml/plm/tests/test_lplr_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def test_lplr_exception_learner():
def test_lplr_exception_and_warning_learner():
# invalid ml_M (must be a classifier with predict_proba)
with pytest.raises(TypeError):
_ = DoubleMLLPLR(dml_data, _DummyNoClassifier(), ml_t, ml_m)
_ = DoubleMLLPLR(dml_data, Lasso(), ml_t, ml_m)
msg = "Invalid learner provided for ml_M: " + r"Lasso\(\) has no method .predict_proba\(\)."
with pytest.raises(TypeError, match=msg):
_ = DoubleMLLPLR(dml_data, Lasso(), ml_t, ml_m)
Expand Down
6 changes: 3 additions & 3 deletions doubleml/tests/test_optuna_multi_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _build_apos_object():
dml_data = dml.DoubleMLData(df, "y", "d")

ml_g = LinearRegression()
ml_m = LogisticRegression(max_iter=200, multi_class="auto")
ml_m = LogisticRegression(max_iter=200)

return dml.DoubleMLAPOS(
dml_data,
Expand All @@ -38,7 +38,7 @@ def _build_apos_object():
def _build_qte_object():
np.random.seed(3141)
dml_data = make_irm_data(n_obs=80, dim_x=5)
ml = LogisticRegression(max_iter=200, multi_class="auto")
ml = LogisticRegression(max_iter=200)

return dml.DoubleMLQTE(
dml_data,
Expand All @@ -57,7 +57,7 @@ def _build_did_multi_object():
dml_panel = DoubleMLPanelData(df, y_col="y", d_cols="d", t_col="t", id_col="id", x_cols=x_cols)

ml_g = LinearRegression()
ml_m = LogisticRegression(max_iter=200, multi_class="auto")
ml_m = LogisticRegression(max_iter=200)

return dml.did.DoubleMLDIDMulti(
obj_dml_data=dml_panel,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ dev = [
"pytest>=8.3.0",
"pytest-cov>=6.0.0",
"xgboost>=2.1.0",
"lightgbm>=4.5.0",
"lightgbm>=4.6.0",
"black>=25.1.0",
"ruff>=0.11.1",
"pre-commit>=4.2.0",
Expand Down
Loading