Skip to content

Commit 1012a6f

Browse files
Merge pull request #314 from KhiopsML/312-accept-types-in-internal-csv-reading-function
312 accept types in internal csv reading function
2 parents d093ae8 + 3bf819f commit 1012a6f

File tree

7 files changed

+233
-140
lines changed

7 files changed

+233
-140
lines changed

doc/samples/samples_sklearn.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ Samples
723723
keep_initial_variables=True,
724724
transform_type_categorical="part_id",
725725
transform_type_numerical="part_id",
726-
transform_pairs="part_id",
726+
transform_type_pairs="part_id",
727727
)
728728
khe.fit(X, y)
729729

khiops/samples/samples_sklearn.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -865,7 +865,7 @@
865865
" keep_initial_variables=True,\n",
866866
" transform_type_categorical=\"part_id\",\n",
867867
" transform_type_numerical=\"part_id\",\n",
868-
" transform_pairs=\"part_id\",\n",
868+
" transform_type_pairs=\"part_id\",\n",
869869
")\n",
870870
"khe.fit(X, y)\n",
871871
"\n",

khiops/samples/samples_sklearn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -761,7 +761,7 @@ def khiops_encoder_with_hyperparameters():
761761
keep_initial_variables=True,
762762
transform_type_categorical="part_id",
763763
transform_type_numerical="part_id",
764-
transform_pairs="part_id",
764+
transform_type_pairs="part_id",
765765
)
766766
khe.fit(X, y)
767767

khiops/sklearn/dataset.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ def get_khiops_variable_name(column_id):
346346
return variable_name
347347

348348

349-
def read_internal_data_table(file_path_or_stream):
349+
def read_internal_data_table(file_path_or_stream, column_dtypes=None):
350350
"""Reads into a DataFrame a data table file with the internal format settings
351351
352352
The table is read with the following settings:
@@ -357,18 +357,34 @@ def read_internal_data_table(file_path_or_stream):
357357
- Use `csv.QUOTE_MINIMAL`
358358
- double quoting enabled (quotes within quotes can be escaped with '""')
359359
- UTF-8 encoding
360+
- User-specified dtypes (optional)
360361
361362
Parameters
362363
----------
363364
file_path_or_stream : str or file object
364365
The path of the internal data table file to be read or a readable file
365366
object.
367+
column_dtypes : dict, optional
368+
Dictionary linking column names with dtypes. See ``dtype`` parameter of the
369+
`pandas.read_csv` function. If not set, then the column types are detected
370+
automatically by pandas.
366371
367372
Returns
368373
-------
369374
`pandas.DataFrame`
370-
The dataframe representation.
375+
The dataframe representation of the data table.
371376
"""
377+
# Change the 'U' types (Unicode strings) to 'O' because pandas does not support them
378+
# in read_csv
379+
if column_dtypes is not None:
380+
execution_column_dtypes = {}
381+
for column_name, dtype in column_dtypes.items():
382+
if hasattr(dtype, "kind") and dtype.kind == "U":
383+
execution_column_dtypes[column_name] = np.dtype("O")
384+
else:
385+
execution_column_dtypes = None
386+
387+
# Read and return the dataframe
372388
return pd.read_csv(
373389
file_path_or_stream,
374390
sep="\t",
@@ -377,6 +393,7 @@ def read_internal_data_table(file_path_or_stream):
377393
quoting=csv.QUOTE_MINIMAL,
378394
doublequote=True,
379395
encoding="utf-8",
396+
dtype=execution_column_dtypes,
380397
)
381398

382399

@@ -1132,6 +1149,11 @@ def __repr__(self):
11321149
f"dtypes={dtypes_str}>"
11331150
)
11341151

1152+
def get_column_dtype(self, column_id):
1153+
if column_id not in self.data_source.dtypes:
1154+
raise KeyError(f"Column '{column_id}' not found in the dtypes field")
1155+
return self.data_source.dtypes[column_id]
1156+
11351157
def create_table_file_for_khiops(
11361158
self, output_dir, sort=True, target_column=None, target_column_id=None
11371159
):
@@ -1214,6 +1236,9 @@ def __repr__(self):
12141236
f"dtype={dtype_str}; target={self.target_column_id}>"
12151237
)
12161238

1239+
def get_column_dtype(self, _):
1240+
return self.data_source.dtype
1241+
12171242
def create_table_file_for_khiops(
12181243
self, output_dir, sort=True, target_column=None, target_column_id=None
12191244
):
@@ -1300,6 +1325,9 @@ def __repr__(self):
13001325
f"dtype={dtype_str}>"
13011326
)
13021327

1328+
def get_column_dtype(self, _):
1329+
return self.data_source.dtype
1330+
13031331
def create_khiops_dictionary(self):
13041332
"""Creates a Khiops dictionary representing this sparse table
13051333

0 commit comments

Comments
 (0)