Skip to content

Commit 4f0bc4d

Browse files
committed
Add reinforce_predictor Core API function as supported by 10.7.3-a.0
Also add relevant samples of its usage.
1 parent 0287d7d commit 4f0bc4d

File tree

6 files changed

+413
-0
lines changed

6 files changed

+413
-0
lines changed

doc/samples/samples.rst

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,37 @@ Samples
686686
kh.interpret_predictor(predictor_file_path, "SNB_Adult", interpretor_file_path)
687687
688688
print(f"The interpretation model is '{interpretor_file_path}'")
689+
.. autofunction:: reinforce_predictor
690+
.. code-block:: python
691+
692+
# Imports
693+
import os
694+
from khiops import core as kh
695+
696+
dictionary_file_path = os.path.join(kh.get_samples_dir(), "Adult", "Adult.kdic")
697+
data_table_path = os.path.join(kh.get_samples_dir(), "Adult", "Adult.txt")
698+
output_dir = os.path.join("kh_samples", "reinforce_predictor")
699+
analysis_report_file_path = os.path.join(output_dir, "AnalysisResults.khj")
700+
reinforced_predictor_file_path = os.path.join(output_dir, "ReinforcedAdultModel.kdic")
701+
702+
# Build prediction model
703+
_, predictor_file_path = kh.train_predictor(
704+
dictionary_file_path,
705+
"Adult",
706+
data_table_path,
707+
"class",
708+
analysis_report_file_path,
709+
)
710+
711+
# Build reinforced predictor
712+
kh.reinforce_predictor(
713+
predictor_file_path,
714+
"SNB_Adult",
715+
reinforced_predictor_file_path,
716+
reinforcement_lever_variables=["occupation"],
717+
)
718+
719+
print(f"The reinforced predictor is '{reinforced_predictor_file_path}'")
689720
.. autofunction:: multiple_train_predictor
690721
.. code-block:: python
691722
@@ -1102,6 +1133,53 @@ Samples
11021133
output_data_table_path,
11031134
additional_data_tables={"Vehicles": vehicles_table_path},
11041135
)
1136+
.. autofunction:: deploy_reinforced_model_mt
1137+
.. code-block:: python
1138+
1139+
# Imports
1140+
import os
1141+
from khiops import core as kh
1142+
1143+
# Set the file paths
1144+
accidents_dir = os.path.join(kh.get_samples_dir(), "AccidentsSummary")
1145+
dictionary_file_path = os.path.join(accidents_dir, "Accidents.kdic")
1146+
accidents_table_path = os.path.join(accidents_dir, "Accidents.txt")
1147+
vehicles_table_path = os.path.join(accidents_dir, "Vehicles.txt")
1148+
output_dir = os.path.join("kh_samples", "deploy_reinforced_model_mt")
1149+
report_file_path = os.path.join(output_dir, "AnalysisResults.khj")
1150+
reinforced_predictor_file_path = os.path.join(output_dir, "ReinforcedModel.kdic")
1151+
output_data_table_path = os.path.join(output_dir, "ReinforcedAccidents.txt")
1152+
1153+
# Train the predictor (see train_predictor_mt for details)
1154+
_, model_dictionary_file_path = kh.train_predictor(
1155+
dictionary_file_path,
1156+
"Accident",
1157+
accidents_table_path,
1158+
"Gravity",
1159+
report_file_path,
1160+
additional_data_tables={"Vehicles": vehicles_table_path},
1161+
max_trees=0,
1162+
)
1163+
1164+
# Reinforce the predictor
1165+
kh.reinforce_predictor(
1166+
model_dictionary_file_path,
1167+
"SNB_Accident",
1168+
reinforced_predictor_file_path,
1169+
reinforcement_target_value="NonLethal",
1170+
reinforcement_lever_variables=["InAgglomeration", "CollisionType"],
1171+
)
1172+
1173+
# Deploy the reinforced model on the database
1174+
# Besides the mandatory parameters, it is specified:
1175+
# - A python dictionary linking data paths to file paths for non-root tables
1176+
kh.deploy_model(
1177+
reinforced_predictor_file_path,
1178+
"Reinforcement_SNB_Accident",
1179+
accidents_table_path,
1180+
output_data_table_path,
1181+
additional_data_tables={"Vehicles": vehicles_table_path},
1182+
)
11051183
.. autofunction:: deploy_model_mt_snowflake
11061184
.. code-block:: python
11071185

khiops/core/api.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,69 @@ def interpret_predictor(
929929
_run_task("interpret_predictor", task_args)
930930

931931

932+
def reinforce_predictor(
933+
dictionary_file_path_or_domain,
934+
predictor_dictionary_name,
935+
reinforced_predictor_file_path,
936+
reinforcement_target_value="",
937+
reinforcement_lever_variables=None,
938+
log_file_path=None,
939+
output_scenario_path=None,
940+
task_file_path=None,
941+
trace=False,
942+
stdout_file_path="",
943+
stderr_file_path="",
944+
max_cores=None,
945+
memory_limit_mb=None,
946+
temp_dir="",
947+
scenario_prologue="",
948+
**kwargs,
949+
):
950+
r"""Builds a reinforced predictor from a predictor
951+
952+
A reinforced predictor is a model which increases the importance of specified lever
953+
variables in order to increase the probability of occurrence of the specified target
954+
value.
955+
956+
Parameters
957+
----------
958+
dictionary_file_path_or_domain : str or `.DictionaryDomain`
959+
Path of a Khiops dictionary file or a DictionaryDomain object.
960+
predictor_dictionary_name : str
961+
Name of the predictor dictionary used while building the reinforced predictor.
962+
reinforced_predictor_file_path : str
963+
Path of the output reinforced predictor dictionary file.
964+
reinforcement_target_value : str, default ""
965+
If this target value is specified, then its probability of occurrence is
966+
tentatively increased.
967+
reinforcement_lever_variables : list of str
968+
The names of variables to use as lever variables while building the
969+
reinforced predictor. Min length: 1. Max length: the total number of variables
970+
in the prediction model.
971+
... :
972+
See :ref:`core-api-common-params`.
973+
974+
Raises
975+
------
976+
`ValueError`
977+
Invalid values of an argument
978+
`TypeError`
979+
Invalid type of an argument
980+
981+
Examples
982+
--------
983+
See the following functions of the ``samples.py`` documentation script:
984+
- `samples.reinforce_predictor()`
985+
- `samples.deploy_reinforced_model_mt()`
986+
"""
987+
# Save the task arguments
988+
# WARNING: Do not move this line, see the top of the "tasks" section for details
989+
task_args = locals()
990+
991+
# Run the task
992+
_run_task("reinforce_predictor", task_args)
993+
994+
932995
def evaluate_predictor(
933996
dictionary_file_path_or_domain,
934997
train_dictionary_name,

khiops/core/internals/tasks/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
extract_keys_from_data_table,
2323
interpret_predictor,
2424
prepare_coclustering_deployment,
25+
reinforce_predictor,
2526
simplify_coclustering,
2627
sort_data_table,
2728
train_coclustering,
@@ -43,6 +44,7 @@
4344
extract_clusters,
4445
extract_keys_from_data_table,
4546
interpret_predictor,
47+
reinforce_predictor,
4648
prepare_coclustering_deployment,
4749
simplify_coclustering,
4850
sort_data_table,
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
######################################################################################
2+
# Copyright (c) 2023-2025 Orange. All rights reserved. #
3+
# This software is distributed under the BSD 3-Clause-clear License, the text of #
4+
# which is available at https://spdx.org/licenses/BSD-3-Clause-Clear.html or #
5+
# see the "LICENSE.md" file for more details. #
6+
######################################################################################
7+
"""reinforce_predictor task family"""
8+
from khiops.core.internals import task as tm
9+
from khiops.core.internals.types import ListType, StringLikeType
10+
11+
# Disable long lines to have readable scenarios
12+
# pylint: disable=line-too-long
13+
TASKS = [
14+
tm.KhiopsTask(
15+
"reinforce_predictor",
16+
"khiops",
17+
"10.7.3-a.0",
18+
[
19+
("dictionary_file_path", StringLikeType),
20+
("predictor_dictionary_name", StringLikeType),
21+
("reinforced_predictor_file_path", StringLikeType),
22+
],
23+
[
24+
("reinforcement_target_value", StringLikeType, ""),
25+
("reinforcement_lever_variables", ListType(StringLikeType), None),
26+
],
27+
["dictionary_file_path", "reinforced_predictor_file_path"],
28+
# pylint: disable=line-too-long
29+
# fmt: off
30+
"""
31+
// Dictionary file and class settings
32+
ClassManagement.OpenFile
33+
ClassFileName __dictionary_file_path__
34+
OK
35+
36+
// Reinforcement settings
37+
TrainDatabase.ClassName __predictor_dictionary_name__
38+
39+
// Reinforce model
40+
LearningTools.ReinforcePredictor
41+
ReinforcedTargetValue __reinforcement_target_value__
42+
43+
LeverAttributes.UnselectAll
44+
__DICT__
45+
__reinforcement_lever_variables__
46+
LeverAttributes.List.Key
47+
LeverAttributes.Used
48+
__END_DICT__
49+
50+
// Build reinforced predictor
51+
BuildReinforcementClass
52+
53+
// Output settings
54+
ClassFileName __reinforced_predictor_file_path__
55+
OK
56+
Exit
57+
""",
58+
# fmt: on
59+
),
60+
]

khiops/samples/samples.ipynb

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,50 @@
908908
"print(f\"The interpretation model is '{interpretor_file_path}'\")"
909909
]
910910
},
911+
{
912+
"cell_type": "markdown",
913+
"metadata": {},
914+
"source": [
915+
"### `reinforce_predictor()`\n\n",
916+
"Builds a reinforced predictor for an existing predictor\n\n The reinforced predictor produces the following reinforcement variables for the\n specified target value to reinforce (i.e. whose probability of occurrence is\n tentatively increased):\n\n - initial score, containing the conditional probability of the target value before\n reinforcement\n - four variables are output in decreasing reinforcement value: name of the lever\n variable, reinforcement part, final score after reinforcement, and class change\n tag.\n\n It calls `~.api.train_predictor` with its mandatory parameters only, and\n `~.api.reinforce_predictor` with its mandatory parameters plus one lever variable.\n \n"
917+
]
918+
},
919+
{
920+
"cell_type": "code",
921+
"execution_count": null,
922+
"metadata": {},
923+
"outputs": [],
924+
"source": [
925+
"# Imports\n",
926+
"import os\n",
927+
"from khiops import core as kh\n",
928+
"\n",
929+
"dictionary_file_path = os.path.join(kh.get_samples_dir(), \"Adult\", \"Adult.kdic\")\n",
930+
"data_table_path = os.path.join(kh.get_samples_dir(), \"Adult\", \"Adult.txt\")\n",
931+
"output_dir = os.path.join(\"kh_samples\", \"reinforce_predictor\")\n",
932+
"analysis_report_file_path = os.path.join(output_dir, \"AnalysisResults.khj\")\n",
933+
"reinforced_predictor_file_path = os.path.join(output_dir, \"ReinforcedAdultModel.kdic\")\n",
934+
"\n",
935+
"# Build prediction model\n",
936+
"_, predictor_file_path = kh.train_predictor(\n",
937+
" dictionary_file_path,\n",
938+
" \"Adult\",\n",
939+
" data_table_path,\n",
940+
" \"class\",\n",
941+
" analysis_report_file_path,\n",
942+
")\n",
943+
"\n",
944+
"# Build reinforced predictor\n",
945+
"kh.reinforce_predictor(\n",
946+
" predictor_file_path,\n",
947+
" \"SNB_Adult\",\n",
948+
" reinforced_predictor_file_path,\n",
949+
" reinforcement_lever_variables=[\"occupation\"],\n",
950+
")\n",
951+
"\n",
952+
"print(f\"The reinforced predictor is '{reinforced_predictor_file_path}'\")"
953+
]
954+
},
911955
{
912956
"cell_type": "markdown",
913957
"metadata": {},
@@ -1454,6 +1498,66 @@
14541498
")"
14551499
]
14561500
},
1501+
{
1502+
"cell_type": "markdown",
1503+
"metadata": {},
1504+
"source": [
1505+
"### `deploy_reinforced_model_mt()`\n\n",
1506+
"Deploys a multi-table reinforced model in the simplest way possible\n\n It trains a predictor, reinforces it with `~.api.reinforce_predictor` (specifying\n the target value and the lever variables) and then calls `~.api.deploy_model` on\n the reinforced model.\n\n In this example, a reinforced Selective Naive Bayes (SNB) model is\n deployed by applying its associated dictionary to the input database.\n The reinforced model predictions are written to the output data table.\n \n"
1507+
]
1508+
},
1509+
{
1510+
"cell_type": "code",
1511+
"execution_count": null,
1512+
"metadata": {},
1513+
"outputs": [],
1514+
"source": [
1515+
"# Imports\n",
1516+
"import os\n",
1517+
"from khiops import core as kh\n",
1518+
"\n",
1519+
"# Set the file paths\n",
1520+
"accidents_dir = os.path.join(kh.get_samples_dir(), \"AccidentsSummary\")\n",
1521+
"dictionary_file_path = os.path.join(accidents_dir, \"Accidents.kdic\")\n",
1522+
"accidents_table_path = os.path.join(accidents_dir, \"Accidents.txt\")\n",
1523+
"vehicles_table_path = os.path.join(accidents_dir, \"Vehicles.txt\")\n",
1524+
"output_dir = os.path.join(\"kh_samples\", \"deploy_reinforced_model_mt\")\n",
1525+
"report_file_path = os.path.join(output_dir, \"AnalysisResults.khj\")\n",
1526+
"reinforced_predictor_file_path = os.path.join(output_dir, \"ReinforcedModel.kdic\")\n",
1527+
"output_data_table_path = os.path.join(output_dir, \"ReinforcedAccidents.txt\")\n",
1528+
"\n",
1529+
"# Train the predictor (see train_predictor_mt for details)\n",
1530+
"_, model_dictionary_file_path = kh.train_predictor(\n",
1531+
" dictionary_file_path,\n",
1532+
" \"Accident\",\n",
1533+
" accidents_table_path,\n",
1534+
" \"Gravity\",\n",
1535+
" report_file_path,\n",
1536+
" additional_data_tables={\"Vehicles\": vehicles_table_path},\n",
1537+
" max_trees=0,\n",
1538+
")\n",
1539+
"\n",
1540+
"# Reinforce the predictor\n",
1541+
"kh.reinforce_predictor(\n",
1542+
" model_dictionary_file_path,\n",
1543+
" \"SNB_Accident\",\n",
1544+
" reinforced_predictor_file_path,\n",
1545+
" reinforcement_target_value=\"NonLethal\",\n",
1546+
" reinforcement_lever_variables=[\"InAgglomeration\", \"CollisionType\"],\n",
1547+
")\n",
1548+
"\n",
1549+
"# Deploy the reinforced model on the database\n",
1550+
"# Besides the mandatory parameters, it is specified:\n",
1551+
"# - A python dictionary linking data paths to file paths for non-root tables\n",
1552+
"kh.deploy_model(\n",
1553+
" reinforced_predictor_file_path,\n",
1554+
" \"Reinforcement_SNB_Accident\",\n",
1555+
" accidents_table_path,\n",
1556+
" output_data_table_path,\n",
1557+
" additional_data_tables={\"Vehicles\": vehicles_table_path},\n",
1558+
")"
1559+
]
1560+
},
14571561
{
14581562
"cell_type": "markdown",
14591563
"metadata": {},

0 commit comments

Comments
 (0)