Skip to content

Commit 9758a3d

Browse files
authored
Merge pull request #410 from KhiopsML/407-separate-intepretation-dictionary-build-from-the-lever-variables-part
Update interpretation support to the Khiops Core alpha 10.7.3-a.0
2 parents b76aeb1 + 4f0bc4d commit 9758a3d

File tree

7 files changed

+447
-44
lines changed

7 files changed

+447
-44
lines changed

doc/samples/samples.rst

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ Samples
128128
129129
# Set the file paths
130130
dictionary_file_path = os.path.join(kh.get_samples_dir(), "Adult", "Adult.kdic")
131-
output_dir = os.path.join("kh_samples", "export_dictionary_file")
131+
output_dir = os.path.join("kh_samples", "export_dictionary_files")
132132
output_dictionary_file_path = os.path.join(output_dir, "ModifiedAdult.kdic")
133133
output_dictionary_json_path = os.path.join(output_dir, "ModifiedAdult.kdicj")
134134
alt_output_dictionary_json_path = os.path.join(output_dir, "AltModifiedAdult.kdicj")
@@ -686,6 +686,37 @@ Samples
686686
kh.interpret_predictor(predictor_file_path, "SNB_Adult", interpretor_file_path)
687687
688688
print(f"The interpretation model is '{interpretor_file_path}'")
689+
.. autofunction:: reinforce_predictor
690+
.. code-block:: python
691+
692+
# Imports
693+
import os
694+
from khiops import core as kh
695+
696+
dictionary_file_path = os.path.join(kh.get_samples_dir(), "Adult", "Adult.kdic")
697+
data_table_path = os.path.join(kh.get_samples_dir(), "Adult", "Adult.txt")
698+
output_dir = os.path.join("kh_samples", "reinforce_predictor")
699+
analysis_report_file_path = os.path.join(output_dir, "AnalysisResults.khj")
700+
reinforced_predictor_file_path = os.path.join(output_dir, "ReinforcedAdultModel.kdic")
701+
702+
# Build prediction model
703+
_, predictor_file_path = kh.train_predictor(
704+
dictionary_file_path,
705+
"Adult",
706+
data_table_path,
707+
"class",
708+
analysis_report_file_path,
709+
)
710+
711+
# Build reinforced predictor
712+
kh.reinforce_predictor(
713+
predictor_file_path,
714+
"SNB_Adult",
715+
reinforced_predictor_file_path,
716+
reinforcement_lever_variables=["occupation"],
717+
)
718+
719+
print(f"The reinforced predictor is '{reinforced_predictor_file_path}'")
689720
.. autofunction:: multiple_train_predictor
690721
.. code-block:: python
691722
@@ -1064,7 +1095,7 @@ Samples
10641095
dictionary_file_path = os.path.join(accidents_dir, "Accidents.kdic")
10651096
accidents_table_path = os.path.join(accidents_dir, "Accidents.txt")
10661097
vehicles_table_path = os.path.join(accidents_dir, "Vehicles.txt")
1067-
output_dir = os.path.join("kh_samples", "deploy_model_mt")
1098+
output_dir = os.path.join("kh_samples", "deploy_model_mt_with_interpretation")
10681099
report_file_path = os.path.join(output_dir, "AnalysisResults.khj")
10691100
interpretor_file_path = os.path.join(output_dir, "InterpretationModel.kdic")
10701101
output_data_table_path = os.path.join(output_dir, "InterpretedAccidents.txt")
@@ -1088,7 +1119,8 @@ Samples
10881119
model_dictionary_file_path,
10891120
"SNB_Accident",
10901121
interpretor_file_path,
1091-
reinforcement_target_value="NonLethal",
1122+
max_variable_importances=3,
1123+
importance_ranking="Individual",
10921124
)
10931125
10941126
# Deploy the interpretation model on the database
@@ -1101,6 +1133,53 @@ Samples
11011133
output_data_table_path,
11021134
additional_data_tables={"Vehicles": vehicles_table_path},
11031135
)
1136+
.. autofunction:: deploy_reinforced_model_mt
1137+
.. code-block:: python
1138+
1139+
# Imports
1140+
import os
1141+
from khiops import core as kh
1142+
1143+
# Set the file paths
1144+
accidents_dir = os.path.join(kh.get_samples_dir(), "AccidentsSummary")
1145+
dictionary_file_path = os.path.join(accidents_dir, "Accidents.kdic")
1146+
accidents_table_path = os.path.join(accidents_dir, "Accidents.txt")
1147+
vehicles_table_path = os.path.join(accidents_dir, "Vehicles.txt")
1148+
output_dir = os.path.join("kh_samples", "deploy_reinforced_model_mt")
1149+
report_file_path = os.path.join(output_dir, "AnalysisResults.khj")
1150+
reinforced_predictor_file_path = os.path.join(output_dir, "ReinforcedModel.kdic")
1151+
output_data_table_path = os.path.join(output_dir, "ReinforcedAccidents.txt")
1152+
1153+
# Train the predictor (see train_predictor_mt for details)
1154+
_, model_dictionary_file_path = kh.train_predictor(
1155+
dictionary_file_path,
1156+
"Accident",
1157+
accidents_table_path,
1158+
"Gravity",
1159+
report_file_path,
1160+
additional_data_tables={"Vehicles": vehicles_table_path},
1161+
max_trees=0,
1162+
)
1163+
1164+
# Reinforce the predictor
1165+
kh.reinforce_predictor(
1166+
model_dictionary_file_path,
1167+
"SNB_Accident",
1168+
reinforced_predictor_file_path,
1169+
reinforcement_target_value="NonLethal",
1170+
reinforcement_lever_variables=["InAgglomeration", "CollisionType"],
1171+
)
1172+
1173+
# Deploy the reinforced model on the database
1174+
# Besides the mandatory parameters, it is specified:
1175+
# - A python dictionary linking data paths to file paths for non-root tables
1176+
kh.deploy_model(
1177+
reinforced_predictor_file_path,
1178+
"Reinforcement_SNB_Accident",
1179+
accidents_table_path,
1180+
output_data_table_path,
1181+
additional_data_tables={"Vehicles": vehicles_table_path},
1182+
)
11041183
.. autofunction:: deploy_model_mt_snowflake
11051184
.. code-block:: python
11061185

khiops/core/api.py

Lines changed: 74 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -375,16 +375,6 @@ def _preprocess_task_arguments(task_args):
375375
if isinstance(task_args["selection_value"], (int, float)):
376376
task_args["selection_value"] = str(task_args["selection_value"])
377377

378-
# Discard the max_variable_importances interpretation parameters
379-
if "max_variable_importances" in task_args:
380-
if task_args["max_variable_importances"] is not None:
381-
warnings.warn(
382-
"The 'max_variable_importances' parameter of the "
383-
"'khiops.core.api.interpret_predictor' function is not supported "
384-
" yet. All model variables' importances are computed."
385-
)
386-
del task_args["max_variable_importances"]
387-
388378
# Detect and replace deprecated data-path syntax on additional_data_tables
389379
# Mutate task_args in the process
390380
for data_path_task_arg_name in (
@@ -880,9 +870,8 @@ def interpret_predictor(
880870
dictionary_file_path_or_domain,
881871
predictor_dictionary_name,
882872
interpretor_file_path,
883-
max_variable_importances=None,
884-
reinforcement_target_value="",
885-
reinforcement_lever_variables=None,
873+
max_variable_importances=100,
874+
importance_ranking="Global",
886875
log_file_path=None,
887876
output_scenario_path=None,
888877
task_file_path=None,
@@ -905,18 +894,80 @@ def interpret_predictor(
905894
Name of the predictor dictionary used while building the interpretation model.
906895
interpretor_file_path : str
907896
Path to the interpretor dictionary file.
908-
max_variable_importances : int, optional
897+
max_variable_importances : int, default 100
909898
Maximum number of variable importances to be selected in the interpretation
910-
model. If not set, then all the variables in the prediction model are
911-
considered.
912-
..note:: Not currently supported; not taken into account if set.
899+
model. If the predictor contains fewer variables than this number, then
900+
all the variables of the predictor are considered.
901+
importance_ranking : str, default "Global"
902+
Ranking of the Shapley values produced by the interpretor. Ca be one of:
903+
904+
- "Global": predictor variables are ranked by decreasing global importance.
905+
906+
- "Individual": predictor variables are ranked by decreasing individual
907+
Shapley value.
908+
... :
909+
See :ref:`core-api-common-params`.
910+
911+
Raises
912+
------
913+
`ValueError`
914+
Invalid values of an argument
915+
`TypeError`
916+
Invalid type of an argument
917+
918+
Examples
919+
--------
920+
See the following functions of the ``samples.py`` documentation script:
921+
- `samples.interpret_predictor()`
922+
- `samples.deploy_model_mt_with_interpretation()`
923+
"""
924+
# Save the task arguments
925+
# WARNING: Do not move this line, see the top of the "tasks" section for details
926+
task_args = locals()
927+
928+
# Run the task
929+
_run_task("interpret_predictor", task_args)
930+
931+
932+
def reinforce_predictor(
933+
dictionary_file_path_or_domain,
934+
predictor_dictionary_name,
935+
reinforced_predictor_file_path,
936+
reinforcement_target_value="",
937+
reinforcement_lever_variables=None,
938+
log_file_path=None,
939+
output_scenario_path=None,
940+
task_file_path=None,
941+
trace=False,
942+
stdout_file_path="",
943+
stderr_file_path="",
944+
max_cores=None,
945+
memory_limit_mb=None,
946+
temp_dir="",
947+
scenario_prologue="",
948+
**kwargs,
949+
):
950+
r"""Builds a reinforced predictor from a predictor
951+
952+
A reinforced predictor is a model which increases the importance of specified lever
953+
variables in order to increase the probability of occurrence of the specified target
954+
value.
955+
956+
Parameters
957+
----------
958+
dictionary_file_path_or_domain : str or `.DictionaryDomain`
959+
Path of a Khiops dictionary file or a DictionaryDomain object.
960+
predictor_dictionary_name : str
961+
Name of the predictor dictionary used while building the reinforced predictor.
962+
reinforced_predictor_file_path : str
963+
Path to the reinforced predictor dictionary file.
913964
reinforcement_target_value : str, default ""
914965
If this target value is specified, then its probability of occurrence is
915966
tentatively increased.
916-
reinforcement_lever_variables : list of str, optional
967+
reinforcement_lever_variables : list of str
917968
The names of variables to use as lever variables while building the
918-
interpretation model. Min length: 0. Max length: the total number of variables
919-
in the prediction model. If not specified, all variables are used.
969+
reinforced predictor. Min length: 1. Max length: the total number of variables
970+
in the prediction model.
920971
... :
921972
See :ref:`core-api-common-params`.
922973
@@ -930,14 +981,15 @@ def interpret_predictor(
930981
Examples
931982
--------
932983
See the following functions of the ``samples.py`` documentation script:
933-
- `samples.interpret_predictor()`
984+
- `samples.reinforce_predictor()`
985+
- `samples.deploy_reinforced_model_mt()`
934986
"""
935987
# Save the task arguments
936988
# WARNING: Do not move this line, see the top of the "tasks" section for details
937989
task_args = locals()
938990

939991
# Run the task
940-
_run_task("interpret_predictor", task_args)
992+
_run_task("reinforce_predictor", task_args)
941993

942994

943995
def evaluate_predictor(

khiops/core/internals/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
extract_keys_from_data_table,
2323
interpret_predictor,
2424
prepare_coclustering_deployment,
25+
reinforce_predictor,
2526
simplify_coclustering,
2627
sort_data_table,
2728
train_coclustering,
@@ -43,6 +44,7 @@
4344
extract_clusters,
4445
extract_keys_from_data_table,
4546
interpret_predictor,
47+
reinforce_predictor,
4648
prepare_coclustering_deployment,
4749
simplify_coclustering,
4850
sort_data_table,

khiops/core/internals/tasks/interpret_predictor.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
######################################################################################
77
"""interpret_predictor task family"""
88
from khiops.core.internals import task as tm
9-
from khiops.core.internals.types import ListType, StringLikeType
9+
from khiops.core.internals.types import IntType, StringLikeType
1010

1111
# Disable long lines to have readable scenarios
1212
# pylint: disable=line-too-long
@@ -21,8 +21,8 @@
2121
("interpretor_file_path", StringLikeType),
2222
],
2323
[
24-
("reinforcement_target_value", StringLikeType, ""),
25-
("reinforcement_lever_variables", ListType(StringLikeType), None),
24+
("max_variable_importances", IntType, 100),
25+
("importance_ranking", StringLikeType, "Global"),
2626
],
2727
["dictionary_file_path", "interpretor_file_path"],
2828
# pylint: disable=line-too-long
@@ -38,14 +38,12 @@
3838
3939
// Interpret model
4040
LearningTools.InterpretPredictor
41-
HowParameter.HowClass __reinforcement_target_value__
4241
43-
__DICT__
44-
__reinforcement_lever_variables__
45-
HowParameter.leverVariablesSpecView.UnselectAll
46-
HowParameter.leverVariablesSpecView.AttributeSpecs.List.Key
47-
HowParameter.leverVariablesSpecView.AttributeSpecs.Used
48-
__END_DICT__
42+
// Number of predictor variables exploited in the interpretation model
43+
ContributionAttributeNumber __max_variable_importances__
44+
45+
// Ranking of the Shapley value produced by the interpretation model
46+
ShapleyValueRanking __importance_ranking__
4947
5048
// Build interpretation dictionary
5149
BuildInterpretationClass
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
######################################################################################
2+
# Copyright (c) 2023-2025 Orange. All rights reserved. #
3+
# This software is distributed under the BSD 3-Clause-clear License, the text of #
4+
# which is available at https://spdx.org/licenses/BSD-3-Clause-Clear.html or #
5+
# see the "LICENSE.md" file for more details. #
6+
######################################################################################
7+
"""reinforce_predictor task family"""
8+
from khiops.core.internals import task as tm
9+
from khiops.core.internals.types import ListType, StringLikeType
10+
11+
# Disable long lines to have readable scenarios
12+
# pylint: disable=line-too-long
13+
TASKS = [
14+
tm.KhiopsTask(
15+
"reinforce_predictor",
16+
"khiops",
17+
"10.7.3-a.0",
18+
[
19+
("dictionary_file_path", StringLikeType),
20+
("predictor_dictionary_name", StringLikeType),
21+
("reinforced_predictor_file_path", StringLikeType),
22+
],
23+
[
24+
("reinforcement_target_value", StringLikeType, ""),
25+
("reinforcement_lever_variables", ListType(StringLikeType), None),
26+
],
27+
["dictionary_file_path", "reinforced_predictor_file_path"],
28+
# pylint: disable=line-too-long
29+
# fmt: off
30+
"""
31+
// Dictionary file and class settings
32+
ClassManagement.OpenFile
33+
ClassFileName __dictionary_file_path__
34+
OK
35+
36+
// Reinforcement settings
37+
TrainDatabase.ClassName __predictor_dictionary_name__
38+
39+
// Reinforce model
40+
LearningTools.ReinforcePredictor
41+
ReinforcedTargetValue __reinforcement_target_value__
42+
43+
LeverAttributes.UnselectAll
44+
__DICT__
45+
__reinforcement_lever_variables__
46+
LeverAttributes.List.Key
47+
LeverAttributes.Used
48+
__END_DICT__
49+
50+
// Build reinforced predictor
51+
BuildReinforcementClass
52+
53+
// Output settings
54+
ClassFileName __reinforced_predictor_file_path__
55+
OK
56+
Exit
57+
""",
58+
# fmt: on
59+
),
60+
]

0 commit comments

Comments
 (0)