Skip to content

Commit bd9843e

Browse files
Remove n_pairs parameter from KhiopsRegressor (#313)
It was never supported.
1 parent 1012a6f commit bd9843e

File tree

3 files changed

+64
-92
lines changed

3 files changed

+64
-92
lines changed

khiops/sklearn/estimators.py

Lines changed: 63 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,37 @@ def _check_numerical_target_type(ds):
198198
)
199199

200200

201+
def _check_pair_parameters(estimator):
202+
assert isinstance(estimator, (KhiopsClassifier, KhiopsEncoder)), type_error_message(
203+
"estimator", estimator, KhiopsClassifier, KhiopsEncoder
204+
)
205+
if not isinstance(estimator.n_pairs, int):
206+
raise TypeError(type_error_message("n_pairs", estimator.n_pairs, int))
207+
if estimator.n_pairs < 0:
208+
raise ValueError("'n_pairs' must be positive")
209+
if estimator.specific_pairs is not None:
210+
if not is_list_like(estimator.specific_pairs):
211+
raise TypeError(
212+
type_error_message(
213+
"specific_pairs", estimator.specific_pairs, "list-like"
214+
)
215+
)
216+
else:
217+
for pair in estimator.specific_pairs:
218+
if not isinstance(pair, tuple):
219+
raise TypeError(type_error_message(pair, pair, tuple))
220+
if not isinstance(estimator.all_possible_pairs, bool):
221+
raise TypeError(
222+
type_error_message("all_possible_pairs", estimator.all_possible_pairs, bool)
223+
)
224+
225+
# Check 'group_target_value' parameter
226+
if not isinstance(estimator.group_target_value, bool):
227+
raise TypeError(
228+
type_error_message("group_target_value", estimator.group_target_value, bool)
229+
)
230+
231+
201232
def _cleanup_dir(target_dir):
202233
"""Cleanups a directory with only files in it
203234
@@ -1379,7 +1410,6 @@ class KhiopsSupervisedEstimator(KhiopsEstimator):
13791410
def __init__(
13801411
self,
13811412
n_features=100,
1382-
n_pairs=0,
13831413
n_trees=10,
13841414
specific_pairs=None,
13851415
all_possible_pairs=True,
@@ -1398,7 +1428,6 @@ def __init__(
13981428
internal_sort=internal_sort,
13991429
)
14001430
self.n_features = n_features
1401-
self.n_pairs = n_pairs
14021431
self.n_trees = n_trees
14031432
self.specific_pairs = specific_pairs
14041433
self.all_possible_pairs = all_possible_pairs
@@ -1489,25 +1518,6 @@ def _fit_check_params(self, ds, **kwargs):
14891518
raise TypeError(type_error_message("n_trees", self.n_trees, int))
14901519
if self.n_trees < 0:
14911520
raise ValueError("'n_trees' must be positive")
1492-
if not isinstance(self.n_pairs, int):
1493-
raise TypeError(type_error_message("n_pairs", self.n_pairs, int))
1494-
if self.n_pairs < 0:
1495-
raise ValueError("'n_pairs' must be positive")
1496-
if self.specific_pairs is not None:
1497-
if not is_list_like(self.specific_pairs):
1498-
raise TypeError(
1499-
type_error_message(
1500-
"specific_pairs", self.specific_pairs, "list-like"
1501-
)
1502-
)
1503-
else:
1504-
for pair in self.specific_pairs:
1505-
if not isinstance(pair, tuple):
1506-
raise TypeError(type_error_message(pair, pair, tuple))
1507-
if not isinstance(self.all_possible_pairs, bool):
1508-
raise TypeError(
1509-
type_error_message("all_possible_pairs", self.all_possible_pairs, bool)
1510-
)
15111521
if self.construction_rules is not None:
15121522
if not is_list_like(self.construction_rules):
15131523
raise TypeError(
@@ -1594,7 +1604,6 @@ def _fit_prepare_training_function_inputs(self, ds, computation_dir):
15941604

15951605
# Rename parameters to be compatible with khiops.core
15961606
kwargs["max_constructed_variables"] = kwargs.pop("n_features")
1597-
kwargs["max_pairs"] = kwargs.pop("n_pairs")
15981607
kwargs["max_trees"] = kwargs.pop("n_trees")
15991608

16001609
# Add the additional_data_tables parameter
@@ -1774,7 +1783,6 @@ class KhiopsPredictor(KhiopsSupervisedEstimator):
17741783
def __init__(
17751784
self,
17761785
n_features=100,
1777-
n_pairs=0,
17781786
n_trees=10,
17791787
n_selected_features=0,
17801788
n_evaluated_features=0,
@@ -1789,7 +1797,6 @@ def __init__(
17891797
):
17901798
super().__init__(
17911799
n_features=n_features,
1792-
n_pairs=n_pairs,
17931800
n_trees=n_trees,
17941801
specific_pairs=specific_pairs,
17951802
all_possible_pairs=all_possible_pairs,
@@ -2081,19 +2088,19 @@ def __init__(
20812088
):
20822089
super().__init__(
20832090
n_features=n_features,
2084-
n_pairs=n_pairs,
20852091
n_trees=n_trees,
20862092
n_selected_features=n_selected_features,
20872093
n_evaluated_features=n_evaluated_features,
2088-
specific_pairs=specific_pairs,
2089-
all_possible_pairs=all_possible_pairs,
20902094
construction_rules=construction_rules,
20912095
verbose=verbose,
20922096
output_dir=output_dir,
20932097
auto_sort=auto_sort,
20942098
key=key,
20952099
internal_sort=internal_sort,
20962100
)
2101+
self.n_pairs = n_pairs
2102+
self.specific_pairs = specific_pairs
2103+
self.all_possible_pairs = all_possible_pairs
20972104
self.group_target_value = group_target_value
20982105
self._khiops_model_prefix = "SNB_"
20992106
self._predicted_target_meta_data_tag = "Prediction"
@@ -2140,11 +2147,19 @@ def _fit_check_params(self, ds, **kwargs):
21402147
# Call parent method
21412148
super()._fit_check_params(ds, **kwargs)
21422149

2143-
# Check 'group_target_value' parameter
2144-
if not isinstance(self.group_target_value, bool):
2145-
raise TypeError(
2146-
type_error_message("group_target_value", self.group_target_value, bool)
2147-
)
2150+
# Check the pair related parameters
2151+
_check_pair_parameters(self)
2152+
2153+
def _fit_prepare_training_function_inputs(self, ds, computation_dir):
2154+
# Call the parent method
2155+
args, kwargs = super()._fit_prepare_training_function_inputs(
2156+
ds, computation_dir
2157+
)
2158+
2159+
# Rename parameters to be compatible with khiops.core
2160+
kwargs["max_pairs"] = kwargs.pop("n_pairs")
2161+
2162+
return args, kwargs
21482163

21492164
def fit(self, X, y, **kwargs):
21502165
"""Fits a Selective Naive Bayes classifier according to X, y
@@ -2409,27 +2424,12 @@ class KhiopsRegressor(RegressorMixin, KhiopsPredictor):
24092424
n_features : int, default 100
24102425
*Multi-table only* : Maximum number of multi-table aggregate features to
24112426
construct. See :doc:`/multi_table_primer` for more details.
2412-
n_pairs : int, default 0
2413-
Maximum number of pair features to construct. These features are 2D grid
2414-
partitions of univariate feature pairs. The grid is optimized such that in each
2415-
cell the target distribution is well approximated by a constant histogram. Only
2416-
pairs that are jointly more informative than their marginals may be taken into
2417-
account in the regressor.
24182427
n_selected_features : int, default 0
24192428
Maximum number of features to be selected in the SNB predictor. If equal to
24202429
0 it selects all the features kept in the training.
24212430
n_evaluated_features : int, default 0
24222431
Maximum number of features to be evaluated in the SNB predictor training. If
24232432
equal to 0 it evaluates all informative features.
2424-
specific_pairs : list of tuple, optional
2425-
User-specified pairs as a list of 2-tuples of feature names. If a given tuple
2426-
contains only one non-empty feature name, then it generates all the pairs
2427-
containing it (within the maximum limit ``n_pairs``). These pairs have top
2428-
priority: they are constructed first.
2429-
all_possible_pairs : bool, default ``True``
2430-
If ``True`` tries to create all possible pairs within the limit ``n_pairs``.
2431-
Pairs specified with ``specific_pairs`` have top priority: they are constructed
2432-
first.
24332433
construction_rules : list of str, optional
24342434
Allowed rules for the automatic feature construction. If not set, it uses all
24352435
possible rules.
@@ -2509,12 +2509,9 @@ class KhiopsRegressor(RegressorMixin, KhiopsPredictor):
25092509
def __init__(
25102510
self,
25112511
n_features=100,
2512-
n_pairs=0,
25132512
n_trees=0,
25142513
n_selected_features=0,
25152514
n_evaluated_features=0,
2516-
specific_pairs=None,
2517-
all_possible_pairs=True,
25182515
construction_rules=None,
25192516
verbose=False,
25202517
output_dir=None,
@@ -2524,12 +2521,9 @@ def __init__(
25242521
):
25252522
super().__init__(
25262523
n_features=n_features,
2527-
n_pairs=n_pairs,
25282524
n_trees=n_trees,
25292525
n_selected_features=n_selected_features,
25302526
n_evaluated_features=n_evaluated_features,
2531-
specific_pairs=specific_pairs,
2532-
all_possible_pairs=all_possible_pairs,
25332527
construction_rules=construction_rules,
25342528
verbose=verbose,
25352529
output_dir=output_dir,
@@ -2821,17 +2815,17 @@ def __init__(
28212815
):
28222816
super().__init__(
28232817
n_features=n_features,
2824-
n_pairs=n_pairs,
28252818
n_trees=n_trees,
2826-
specific_pairs=specific_pairs,
2827-
all_possible_pairs=all_possible_pairs,
28282819
construction_rules=construction_rules,
28292820
verbose=verbose,
28302821
output_dir=output_dir,
28312822
auto_sort=auto_sort,
28322823
key=key,
28332824
internal_sort=internal_sort,
28342825
)
2826+
self.n_pairs = n_pairs
2827+
self.specific_pairs = specific_pairs
2828+
self.all_possible_pairs = all_possible_pairs
28352829
self.categorical_target = categorical_target
28362830
self.group_target_value = group_target_value
28372831
self.transform_type_categorical = transform_type_categorical
@@ -2904,6 +2898,9 @@ def _fit_check_params(self, ds, **kwargs):
29042898
# Call parent method
29052899
super()._fit_check_params(ds, **kwargs)
29062900

2901+
# Check the pair related parameters
2902+
_check_pair_parameters(self)
2903+
29072904
# Check 'transform_type_categorical' parameter
29082905
if not isinstance(self.transform_type_categorical, str):
29092906
raise TypeError(
@@ -2922,6 +2919,15 @@ def _fit_check_params(self, ds, **kwargs):
29222919
)
29232920
self._numerical_transform_method() # Raises ValueError if invalid
29242921

2922+
# Check 'transform_type_pairs' parameter
2923+
if not isinstance(self.transform_type_pairs, str):
2924+
raise TypeError(
2925+
type_error_message(
2926+
"transform_type_pairs", self.transform_type_pairs, str
2927+
)
2928+
)
2929+
self._pairs_transform_method() # Raises ValueError if invalid
2930+
29252931
# Check coherence between transformation types and tree number
29262932
if (
29272933
self.transform_type_categorical is None
@@ -2932,14 +2938,6 @@ def _fit_check_params(self, ds, **kwargs):
29322938
"transform_type_categorical and transform_type_numerical "
29332939
"cannot be both None with n_trees == 0."
29342940
)
2935-
# Check 'transform_type_pairs' parameter
2936-
if not isinstance(self.transform_type_pairs, str):
2937-
raise TypeError(
2938-
type_error_message(
2939-
"transform_type_pairs", self.transform_type_pairs, str
2940-
)
2941-
)
2942-
self._pairs_transform_method() # Raises ValueError if invalid
29432941

29442942
# Check 'informative_features_only' parameter
29452943
if not isinstance(self.informative_features_only, bool):
@@ -3028,6 +3026,7 @@ def _fit_prepare_training_function_inputs(self, ds, computation_dir):
30283026
)
30293027
# Rename encoder parameters, delete unused ones
30303028
# to be compatible with khiops.core
3029+
kwargs["max_pairs"] = kwargs.pop("n_pairs")
30313030
kwargs["keep_initial_categorical_variables"] = kwargs["keep_initial_variables"]
30323031
kwargs["keep_initial_numerical_variables"] = kwargs.pop(
30333032
"keep_initial_variables"

tests/test_estimator_attributes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def test_regressor_attributes_monotable(self):
214214
adult_df = pd.read_csv(adult_dataset_path, sep="\t").sample(750)
215215
X = adult_df.drop("age", axis=1)
216216
y = adult_df["age"]
217-
khr_adult = KhiopsRegressor(n_trees=0, n_pairs=5)
217+
khr_adult = KhiopsRegressor(n_trees=0)
218218
with warnings.catch_warnings():
219219
warnings.filterwarnings(
220220
action="ignore",

tests/test_sklearn.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1114,12 +1114,9 @@ def setUpClass(cls):
11141114
"field_separator": "\t",
11151115
"detect_format": False,
11161116
"header_line": True,
1117-
"max_pairs": 1,
11181117
"max_trees": 0,
11191118
"max_selected_variables": 1,
11201119
"max_evaluated_variables": 3,
1121-
"specific_pairs": [("age", "race")],
1122-
"all_possible_pairs": False,
11231120
"construction_rules": ["TableMode", "TableSelection"],
11241121
"additional_data_tables": {},
11251122
}
@@ -1206,12 +1203,9 @@ def setUpClass(cls):
12061203
"field_separator": "\t",
12071204
"detect_format": False,
12081205
"header_line": True,
1209-
"max_pairs": 1,
12101206
"max_trees": 0,
12111207
"max_selected_variables": 1,
12121208
"max_evaluated_variables": 3,
1213-
"specific_pairs": [("age", "race")],
1214-
"all_possible_pairs": False,
12151209
"construction_rules": ["TableMode", "TableSelection"],
12161210
"additional_data_tables": {},
12171211
}
@@ -1306,12 +1300,9 @@ def setUpClass(cls):
13061300
"detect_format": False,
13071301
"header_line": True,
13081302
"max_constructed_variables": 10,
1309-
"max_pairs": 1,
13101303
"max_trees": 0,
13111304
"max_selected_variables": 1,
13121305
"max_evaluated_variables": 3,
1313-
"specific_pairs": [],
1314-
"all_possible_pairs": False,
13151306
"construction_rules": ["TableMode", "TableSelection"],
13161307
"additional_data_tables": {
13171308
"SpliceJunction`SpliceJunctionDNA"
@@ -1416,12 +1407,9 @@ def setUpClass(cls):
14161407
"detect_format": False,
14171408
"header_line": True,
14181409
"max_constructed_variables": 10,
1419-
"max_pairs": 1,
14201410
"max_trees": 0,
14211411
"max_selected_variables": 1,
14221412
"max_evaluated_variables": 3,
1423-
"specific_pairs": [],
1424-
"all_possible_pairs": False,
14251413
"construction_rules": ["TableMode", "TableSelection"],
14261414
"log_file_path": os.path.join(
14271415
cls.output_dir, "khiops.log"
@@ -2407,11 +2395,8 @@ def test_parameter_transfer_regressor_fit_from_monotable_dataframe(self):
24072395
schema_type="monotable",
24082396
source_type="dataframe",
24092397
extra_estimator_kwargs={
2410-
"n_pairs": 1,
24112398
"n_selected_features": 1,
24122399
"n_evaluated_features": 3,
2413-
"specific_pairs": [("age", "race")],
2414-
"all_possible_pairs": False,
24152400
"construction_rules": ["TableMode", "TableSelection"],
24162401
},
24172402
)
@@ -2426,11 +2411,8 @@ def test_parameter_transfer_regressor_fit_from_monotable_dataframe_with_df_y(
24262411
schema_type="monotable",
24272412
source_type="dataframe_xy",
24282413
extra_estimator_kwargs={
2429-
"n_pairs": 1,
24302414
"n_selected_features": 1,
24312415
"n_evaluated_features": 3,
2432-
"specific_pairs": [("age", "race")],
2433-
"all_possible_pairs": False,
24342416
"construction_rules": ["TableMode", "TableSelection"],
24352417
},
24362418
)
@@ -2443,11 +2425,8 @@ def test_parameter_transfer_regressor_fit_from_monotable_file_dataset(self):
24432425
schema_type="monotable",
24442426
source_type="file_dataset",
24452427
extra_estimator_kwargs={
2446-
"n_pairs": 1,
24472428
"n_selected_features": 1,
24482429
"n_evaluated_features": 3,
2449-
"specific_pairs": [("age", "race")],
2450-
"all_possible_pairs": False,
24512430
"construction_rules": ["TableMode", "TableSelection"],
24522431
},
24532432
)
@@ -2461,12 +2440,9 @@ def test_parameter_transfer_regressor_fit_from_multitable_dataframe(self):
24612440
source_type="dataframe",
24622441
extra_estimator_kwargs={
24632442
"n_features": 10,
2464-
"n_pairs": 1,
24652443
"n_trees": 0,
24662444
"n_selected_features": 1,
24672445
"n_evaluated_features": 3,
2468-
"specific_pairs": [],
2469-
"all_possible_pairs": False,
24702446
"construction_rules": ["TableMode", "TableSelection"],
24712447
},
24722448
)
@@ -2480,12 +2456,9 @@ def test_parameter_transfer_regressor_fit_from_multitable_file_dataset(self):
24802456
source_type="file_dataset",
24812457
extra_estimator_kwargs={
24822458
"n_features": 10,
2483-
"n_pairs": 1,
24842459
"n_trees": 0,
24852460
"n_selected_features": 1,
24862461
"n_evaluated_features": 3,
2487-
"specific_pairs": [],
2488-
"all_possible_pairs": False,
24892462
"construction_rules": ["TableMode", "TableSelection"],
24902463
},
24912464
)

0 commit comments

Comments
 (0)