Skip to content

Commit 3fd4146

Browse files
authored
Handle statsmodels PerfectSeparationWarning on 0.14.0+ (#3356)
1 parent a48601d commit 3fd4146

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

seaborn/regression.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,20 @@ def reg_func(_x, _y):
264264

265265
def fit_statsmodels(self, grid, model, **kwargs):
266266
"""More general regression function using statsmodels objects."""
267-
import statsmodels.genmod.generalized_linear_model as glm
267+
import statsmodels.tools.sm_exceptions as sme
268268
X, y = np.c_[np.ones(len(self.x)), self.x], self.y
269269
grid = np.c_[np.ones(len(grid)), grid]
270270

271271
def reg_func(_x, _y):
272+
err_classes = (sme.PerfectSeparationError,)
272273
try:
273-
yhat = model(_y, _x, **kwargs).fit().predict(grid)
274-
except glm.PerfectSeparationError:
274+
with warnings.catch_warnings():
275+
if hasattr(sme, "PerfectSeparationWarning"):
276+
# statsmodels>=0.14.0
277+
warnings.simplefilter("error", sme.PerfectSeparationWarning)
278+
err_classes = (*err_classes, sme.PerfectSeparationWarning)
279+
yhat = model(_y, _x, **kwargs).fit().predict(grid)
280+
except err_classes:
275281
yhat = np.empty(len(grid))
276282
yhat.fill(np.nan)
277283
return yhat

0 commit comments

Comments
 (0)