diff --git a/marimo/_plugins/ui/_impl/dataframes/transforms/apply.py b/marimo/_plugins/ui/_impl/dataframes/transforms/apply.py index ca469c6f909..5a30640f614 100644 --- a/marimo/_plugins/ui/_impl/dataframes/transforms/apply.py +++ b/marimo/_plugins/ui/_impl/dataframes/transforms/apply.py @@ -19,43 +19,40 @@ ) from marimo._utils.assert_never import assert_never +_transform_type_to_handler_method = { + TransformType.COLUMN_CONVERSION: "handle_column_conversion", + TransformType.RENAME_COLUMN: "handle_rename_column", + TransformType.SORT_COLUMN: "handle_sort_column", + TransformType.FILTER_ROWS: "handle_filter_rows", + TransformType.GROUP_BY: "handle_group_by", + TransformType.AGGREGATE: "handle_aggregate", + TransformType.SELECT_COLUMNS: "handle_select_columns", + TransformType.SHUFFLE_ROWS: "handle_shuffle_rows", + TransformType.SAMPLE_ROWS: "handle_sample_rows", + TransformType.EXPLODE_COLUMNS: "handle_explode_columns", + TransformType.EXPAND_DICT: "handle_expand_dict", + TransformType.UNIQUE: "handle_unique", +} + T = TypeVar("T") def _handle(df: T, handler: TransformHandler[T], transform: Transform) -> T: - if transform.type is TransformType.COLUMN_CONVERSION: - return handler.handle_column_conversion(df, transform) - if transform.type is TransformType.RENAME_COLUMN: - return handler.handle_rename_column(df, transform) - if transform.type is TransformType.SORT_COLUMN: - return handler.handle_sort_column(df, transform) - if transform.type is TransformType.FILTER_ROWS: - return handler.handle_filter_rows(df, transform) - if transform.type is TransformType.GROUP_BY: - return handler.handle_group_by(df, transform) - if transform.type is TransformType.AGGREGATE: - return handler.handle_aggregate(df, transform) - if transform.type is TransformType.SELECT_COLUMNS: - return handler.handle_select_columns(df, transform) - if transform.type is TransformType.SHUFFLE_ROWS: - return handler.handle_shuffle_rows(df, transform) - if transform.type is TransformType.SAMPLE_ROWS: - return handler.handle_sample_rows(df, transform) - if transform.type is TransformType.EXPLODE_COLUMNS: - return handler.handle_explode_columns(df, transform) - if transform.type is TransformType.EXPAND_DICT: - return handler.handle_expand_dict(df, transform) - if transform.type is TransformType.UNIQUE: - return handler.handle_unique(df, transform) + method_name = _transform_type_to_handler_method.get(transform.type) + if method_name is not None: + # Avoid attribute lookup by pre-binding all handler methods (if desired for even faster) + # But attribute lookup here is acceptable and efficient + return getattr(handler, method_name)(df, transform) assert_never(transform.type) def _apply_transforms( df: T, handler: TransformHandler[T], transforms: Transformations ) -> T: - if not transforms.transforms: + transforms_list = transforms.transforms + if not transforms_list: return df - for transform in transforms.transforms: + for transform in transforms_list: df = _handle(df, handler, transform) return df