Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 23 additions & 26 deletions marimo/_plugins/ui/_impl/dataframes/transforms/apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,43 +19,40 @@
)
from marimo._utils.assert_never import assert_never

_transform_type_to_handler_method = {
TransformType.COLUMN_CONVERSION: "handle_column_conversion",
TransformType.RENAME_COLUMN: "handle_rename_column",
TransformType.SORT_COLUMN: "handle_sort_column",
TransformType.FILTER_ROWS: "handle_filter_rows",
TransformType.GROUP_BY: "handle_group_by",
TransformType.AGGREGATE: "handle_aggregate",
TransformType.SELECT_COLUMNS: "handle_select_columns",
TransformType.SHUFFLE_ROWS: "handle_shuffle_rows",
TransformType.SAMPLE_ROWS: "handle_sample_rows",
TransformType.EXPLODE_COLUMNS: "handle_explode_columns",
TransformType.EXPAND_DICT: "handle_expand_dict",
TransformType.UNIQUE: "handle_unique",
}

T = TypeVar("T")


def _handle(df: T, handler: TransformHandler[T], transform: Transform) -> T:
if transform.type is TransformType.COLUMN_CONVERSION:
return handler.handle_column_conversion(df, transform)
if transform.type is TransformType.RENAME_COLUMN:
return handler.handle_rename_column(df, transform)
if transform.type is TransformType.SORT_COLUMN:
return handler.handle_sort_column(df, transform)
if transform.type is TransformType.FILTER_ROWS:
return handler.handle_filter_rows(df, transform)
if transform.type is TransformType.GROUP_BY:
return handler.handle_group_by(df, transform)
if transform.type is TransformType.AGGREGATE:
return handler.handle_aggregate(df, transform)
if transform.type is TransformType.SELECT_COLUMNS:
return handler.handle_select_columns(df, transform)
if transform.type is TransformType.SHUFFLE_ROWS:
return handler.handle_shuffle_rows(df, transform)
if transform.type is TransformType.SAMPLE_ROWS:
return handler.handle_sample_rows(df, transform)
if transform.type is TransformType.EXPLODE_COLUMNS:
return handler.handle_explode_columns(df, transform)
if transform.type is TransformType.EXPAND_DICT:
return handler.handle_expand_dict(df, transform)
if transform.type is TransformType.UNIQUE:
return handler.handle_unique(df, transform)
method_name = _transform_type_to_handler_method.get(transform.type)
if method_name is not None:
# Avoid attribute lookup by pre-binding all handler methods (if desired for even faster)
# But attribute lookup here is acceptable and efficient
return getattr(handler, method_name)(df, transform)
assert_never(transform.type)


def _apply_transforms(
df: T, handler: TransformHandler[T], transforms: Transformations
) -> T:
if not transforms.transforms:
transforms_list = transforms.transforms
if not transforms_list:
return df
for transform in transforms.transforms:
for transform in transforms_list:
df = _handle(df, handler, transform)
return df

Expand Down