Skip to content

Commit 6373ff5

Browse files
authored
Merge pull request #354 from KhiopsML/support-float-and-bool-targets-in-khiops-classifiers
Support float and bool targets in khiops classifiers
2 parents 2d9946b + e7674ec commit 6373ff5

File tree

4 files changed

+92
-10
lines changed

4 files changed

+92
-10
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
- Example: 10.2.1.4 is the 5th version that supports khiops 10.2.1.
77
- Internals: Changes in *Internals* sections are unlikely to be of interest for data scientists.
88

9+
## Unreleased
10+
11+
### Added
12+
- (`sklearn`) Support for boolean and float targets in `KhiopsClassifier`.
13+
914
## 10.3.0.0 - 2025-02-10
1015

1116
### Fixed

khiops/sklearn/dataset.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numpy as np
1616
import pandas as pd
17+
import sklearn
1718
from scipy import sparse as sp
1819
from sklearn.utils import check_array
1920
from sklearn.utils.validation import column_or_1d
@@ -430,6 +431,19 @@ def write_internal_data_table(dataframe, file_path_or_stream):
430431
)
431432

432433

434+
def _column_or_1d_with_dtype(y, dtype=None):
435+
# 'dtype' has been introduced on `column_or_1d' since Scikit-learn 1.2;
436+
if sklearn.__version__ < "1.2":
437+
if pd.api.types.is_string_dtype(dtype) and y.isin(["True", "False"]).all():
438+
warnings.warn(
439+
"'y' stores strings restricted to 'True'/'False' values: "
440+
"The predict method may return a bool vector."
441+
)
442+
return column_or_1d(y, warn=True)
443+
else:
444+
return column_or_1d(y, warn=True, dtype=dtype)
445+
446+
433447
class Dataset:
434448
"""A representation of a dataset
435449
@@ -738,8 +752,22 @@ def _init_target_column(self, y):
738752
if isinstance(y, str):
739753
y_checked = y
740754
else:
741-
y_checked = column_or_1d(y, warn=True)
742-
755+
if hasattr(y, "dtype"):
756+
if isinstance(y.dtype, pd.CategoricalDtype):
757+
y_checked = _column_or_1d_with_dtype(
758+
y, dtype=y.dtype.categories.dtype
759+
)
760+
else:
761+
y_checked = _column_or_1d_with_dtype(y, dtype=y.dtype)
762+
elif hasattr(y, "dtypes"):
763+
if isinstance(y.dtypes[0], pd.CategoricalDtype):
764+
y_checked = _column_or_1d_with_dtype(
765+
y, dtype=y.dtypes[0].categories.dtype
766+
)
767+
else:
768+
y_checked = _column_or_1d_with_dtype(y)
769+
else:
770+
y_checked = _column_or_1d_with_dtype(y)
743771
# Check the target type coherence with those of X's tables
744772
if isinstance(
745773
self.main_table, (PandasTable, SparseTable, NumpyTable)

khiops/sklearn/estimators.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def _check_categorical_target_type(ds):
154154
or pd.api.types.is_string_dtype(ds.target_column.dtype)
155155
or pd.api.types.is_integer_dtype(ds.target_column.dtype)
156156
or pd.api.types.is_float_dtype(ds.target_column.dtype)
157+
or pd.api.types.is_bool_dtype(ds.target_column.dtype)
157158
):
158159
raise ValueError(
159160
f"'y' has invalid type '{ds.target_column_type}'. "
@@ -2123,6 +2124,24 @@ def _is_real_target_dtype_integer(self):
21232124
)
21242125
)
21252126

2127+
def _is_real_target_dtype_float(self):
2128+
return self._original_target_dtype is not None and (
2129+
pd.api.types.is_float_dtype(self._original_target_dtype)
2130+
or (
2131+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2132+
and pd.api.types.is_float_dtype(self._original_target_dtype.categories)
2133+
)
2134+
)
2135+
2136+
def _is_real_target_dtype_bool(self):
2137+
return self._original_target_dtype is not None and (
2138+
pd.api.types.is_bool_dtype(self._original_target_dtype)
2139+
or (
2140+
isinstance(self._original_target_dtype, pd.CategoricalDtype)
2141+
and pd.api.types.is_bool_dtype(self._original_target_dtype.categories)
2142+
)
2143+
)
2144+
21262145
def _sorted_prob_variable_names(self):
21272146
"""Returns the model probability variable names in the order of self.classes_"""
21282147
self._assert_is_fitted()
@@ -2227,8 +2246,13 @@ def _fit_training_post_process(self, ds):
22272246
for key in variable.meta_data.keys:
22282247
if key.startswith("TargetProb"):
22292248
self.classes_.append(variable.meta_data.get_value(key))
2230-
if ds.is_in_memory and self._is_real_target_dtype_integer():
2231-
self.classes_ = [int(class_value) for class_value in self.classes_]
2249+
if ds.is_in_memory:
2250+
if self._is_real_target_dtype_integer():
2251+
self.classes_ = [int(class_value) for class_value in self.classes_]
2252+
elif self._is_real_target_dtype_float():
2253+
self.classes_ = [float(class_value) for class_value in self.classes_]
2254+
elif self._is_real_target_dtype_bool():
2255+
self.classes_ = [class_value == "True" for class_value in self.classes_]
22322256
self.classes_.sort()
22332257
self.classes_ = column_or_1d(self.classes_)
22342258

@@ -2283,9 +2307,10 @@ def predict(self, X):
22832307
-------
22842308
`ndarray <numpy.ndarray>`
22852309
An array containing the encoded columns. A first column containing key
2286-
column ids is added in multi-table mode. The `numpy.dtype` of the array is
2287-
integer if the classifier was learned with an integer ``y``. Otherwise it
2288-
will be ``str``.
2310+
column ids is added in multi-table mode. The `numpy.dtype` of the array
2311+
matches the type of ``y`` used during training. It will be integer, float,
2312+
or boolean if the classifier was trained with a ``y`` of the corresponding
2313+
type. Otherwise it will be ``str``.
22892314
22902315
The key columns are added for multi-table tasks.
22912316
"""

tests/test_sklearn_output_types.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,12 @@ def test_classifier_output_types(self):
7171
khc = KhiopsClassifier(n_trees=0)
7272
khc.fit(X, y)
7373
y_pred = khc.predict(X)
74+
khc.fit(X_mt, y)
75+
y_mt_pred = khc.predict(X_mt)
76+
7477
y_bin = y.replace({0: 0, 1: 0, 2: 1})
7578
khc.fit(X, y_bin)
7679
y_bin_pred = khc.predict(X)
77-
khc.fit(X_mt, y)
78-
khc.export_report_file("report.khj")
79-
y_mt_pred = khc.predict(X_mt)
8080
khc.fit(X_mt, y_bin)
8181
y_mt_bin_pred = khc.predict(X_mt)
8282

@@ -85,6 +85,8 @@ def test_classifier_output_types(self):
8585
"ys": {
8686
"int": y,
8787
"int binary": y_bin,
88+
"float": y.astype(float),
89+
"bool": y.replace({0: True, 1: True, 2: False}),
8890
"string": self._replace(y, {0: "se", 1: "vi", 2: "ve"}),
8991
"string binary": self._replace(y_bin, {0: "vi_or_se", 1: "ve"}),
9092
"int as string": self._replace(y, {0: "8", 1: "9", 2: "10"}),
@@ -93,30 +95,42 @@ def test_classifier_output_types(self):
9395
"cat string": pd.Series(
9496
self._replace(y, {0: "se", 1: "vi", 2: "ve"})
9597
).astype("category"),
98+
"cat float": y.astype(float).astype("category"),
99+
"cat bool": y.replace({0: True, 1: True, 2: False}).astype("category"),
96100
},
97101
"y_type_check": {
98102
"int": pd.api.types.is_integer_dtype,
99103
"int binary": pd.api.types.is_integer_dtype,
104+
"float": pd.api.types.is_float_dtype,
105+
"bool": pd.api.types.is_bool_dtype,
100106
"string": pd.api.types.is_string_dtype,
101107
"string binary": pd.api.types.is_string_dtype,
102108
"int as string": pd.api.types.is_string_dtype,
103109
"int as string binary": pd.api.types.is_string_dtype,
104110
"cat int": pd.api.types.is_integer_dtype,
105111
"cat string": pd.api.types.is_string_dtype,
112+
"cat float": pd.api.types.is_float_dtype,
113+
"cat bool": pd.api.types.is_bool_dtype,
106114
},
107115
"expected_classes": {
108116
"int": column_or_1d([0, 1, 2]),
109117
"int binary": column_or_1d([0, 1]),
118+
"float": column_or_1d([0.0, 1.0, 2.0]),
119+
"bool": column_or_1d([False, True]),
110120
"string": column_or_1d(["se", "ve", "vi"]),
111121
"string binary": column_or_1d(["ve", "vi_or_se"]),
112122
"int as string": column_or_1d(["10", "8", "9"]),
113123
"int as string binary": column_or_1d(["10", "89"]),
114124
"cat int": column_or_1d([0, 1, 2]),
115125
"cat string": column_or_1d(["se", "ve", "vi"]),
126+
"cat float": column_or_1d([0.0, 1.0, 2.0]),
127+
"cat bool": column_or_1d([False, True]),
116128
},
117129
"expected_y_preds": {
118130
"mono": {
119131
"int": y_pred,
132+
"float": y_pred.astype(float),
133+
"bool": self._replace(y_bin_pred, {0: True, 1: False}),
120134
"int binary": y_bin_pred,
121135
"string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
122136
"string binary": self._replace(
@@ -128,9 +142,15 @@ def test_classifier_output_types(self):
128142
),
129143
"cat int": y_pred,
130144
"cat string": self._replace(y_pred, {0: "se", 1: "vi", 2: "ve"}),
145+
"cat float": self._replace(
146+
y_pred, {target: float(target) for target in (0, 1, 2)}
147+
),
148+
"cat bool": self._replace(y_bin_pred, {0: True, 1: False}),
131149
},
132150
"multi": {
133151
"int": y_mt_pred,
152+
"float": y_mt_pred.astype(float),
153+
"bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
134154
"int binary": y_mt_bin_pred,
135155
"string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
136156
"string binary": self._replace(
@@ -144,6 +164,10 @@ def test_classifier_output_types(self):
144164
),
145165
"cat int": y_mt_pred,
146166
"cat string": self._replace(y_mt_pred, {0: "se", 1: "vi", 2: "ve"}),
167+
"cat float": self._replace(
168+
y_mt_pred, {target: float(target) for target in (0, 1, 2)}
169+
),
170+
"cat bool": self._replace(y_mt_bin_pred, {0: True, 1: False}),
147171
},
148172
},
149173
"Xs": {

0 commit comments

Comments
 (0)