Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 26 additions & 26 deletions daft_lance/_lance.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .lance_compaction import compact_files_internal
from .lance_merge_column import merge_columns_from_df, merge_columns_internal
from .lance_scalar_index import create_scalar_index_internal
from .lance_scan import LanceDBScanOperator
from .lance_scan import LanceScanOperator
from .utils import construct_lance_dataset

if TYPE_CHECKING:
Expand All @@ -43,12 +43,12 @@ def read_lance(
include_fragment_id: bool | None = None,
checkpoint: CheckpointConfig | None = None,
) -> DataFrame:
"""Create a DataFrame from a LanceDB table.
"""Create a DataFrame from a Lance dataset.

Args:
uri: The URI of the Lance table to read from. Accepts a local path or an
object-store URI like "s3://bucket/path".
io_config: A custom IOConfig to use when accessing LanceDB data. Defaults to None.
io_config: A custom IOConfig to use when accessing Lance data. Defaults to None.
version : optional, int | str
If specified, load a specific version of the Lance dataset. Else, loads the
latest version. A version number (`int`) or a tag (`str`) can be provided.
Expand Down Expand Up @@ -99,25 +99,25 @@ def read_lance(
already exists in the store are skipped on re-run. Requires the Ray runner.

Returns:
DataFrame: a DataFrame with the schema converted from the specified LanceDB table
DataFrame: a DataFrame with the schema converted from the specified Lance dataset

This function requires the use of [LanceDB](https://lancedb.github.io/lancedb/), which is the Python library for the LanceDB project.
This function reads Lance datasets via the Lance Python package.
To ensure that this is installed with Daft, you may install: `pip install daft[lance]`

Examples:
Read a local LanceDB table:
Read a local Lance dataset:
>>> df = daft.read_lance("/path/to/lance/data/")
>>> df.show()

Read a LanceDB table and specify a version:
Read a Lance dataset and specify a version:
>>> df = daft.read_lance("/path/to/lance/data/", version=1)
>>> df.show()

Read a LanceDB table with fragment grouping:
Read a Lance dataset with fragment grouping:
>>> df = daft.read_lance("/path/to/lance/data/", fragment_group_size=5)
>>> df.show()

Read a LanceDB table from a public S3 bucket:
Read a Lance dataset from a public S3 bucket:
>>> from daft.io import S3Config, IOConfig
>>> io_config = IOConfig(s3=S3Config(region="us-west-2", anonymous=True))
>>> df = daft.read_lance("s3://daft-oss-public-data/lance/words-test-dataset", io_config=io_config)
Expand Down Expand Up @@ -147,7 +147,7 @@ def read_lance(
metadata_cache_size_bytes=metadata_cache_size_bytes,
)

lance_operator = LanceDBScanOperator(
lance_operator = LanceScanOperator(
ds,
fragment_group_size=fragment_group_size,
include_fragment_id=include_fragment_id,
Expand Down Expand Up @@ -178,14 +178,14 @@ def merge_columns(
default_scan_options: dict[str, Any] | None = None,
metadata_cache_size_bytes: int | None = None,
) -> LanceDataset:
"""Merge new columns into a LanceDB table using a transformation function.
"""Merge new columns into a Lance dataset using a transformation function.

This function modifies the LanceDB table in-place by adding new columns computed
This function modifies the Lance dataset in-place by adding new columns computed
from existing data using a transformation function. It does not return a DataFrame.

Args:
uri: The URI of the Lance table (supports remote URLs to object stores such as `s3://` or `gs://`)
io_config: A custom IOConfig to use when accessing LanceDB data. Defaults to None.
io_config: A custom IOConfig to use when accessing Lance data. Defaults to None.
transform: A transformation function or UDF to apply to the data.
read_columns: List of column names to read for the transformation.
reader_schema: Schema for the reader.
Expand All @@ -204,17 +204,17 @@ def merge_columns(
None: This function modifies the table in-place and does not return a value.

Note:
This function requires the use of [LanceDB](https://lancedb.github.io/lancedb/), which is the Python library for the LanceDB project.
This function writes Lance datasets via the Lance Python package.
To ensure that this is installed with Daft, you may install: `pip install daft[lance]`

Examples:
Merge new columns into a local LanceDB table:
Merge new columns into a local Lance dataset:
>>> def double_score(batch):
... # Example transformation function
... import pyarrow.compute as pc
...
... return batch.append_column("new_column", pc.multiply(batch["c"], 2))
>>> daft_lance.merge_columns("s3://my-lancedb-bucket/data/", transform=double_score)
>>> daft_lance.merge_columns("s3://my-lance-bucket/data/", transform=double_score)
"""
if transform is None:
raise ValueError(
Expand Down Expand Up @@ -273,13 +273,13 @@ def merge_columns_df(
) -> Any:
"""Row-level merge columns entrypoint using a DataFrame.

This function modifies the LanceDB table in-place by merging new columns from a DataFrame
This function modifies the Lance dataset in-place by merging new columns from a DataFrame
into existing fragments using a row-level join. It does not return a DataFrame.

Args:
df: DataFrame containing the new columns to merge along with fragment_id and join key columns
uri: URL to the LanceDB table (supports remote URLs to object stores such as `s3://` or `gs://`)
io_config: A custom IOConfig to use when accessing LanceDB data. Defaults to None.
uri: URL to the Lance dataset (supports remote URLs to object stores such as `s3://` or `gs://`)
io_config: A custom IOConfig to use when accessing Lance data. Defaults to None.
read_columns: List of column names to read for the transformation.
reader_schema: Schema for the reader.
storage_options: Extra options for storage connection.
Expand All @@ -300,22 +300,22 @@ def merge_columns_df(
None: This function modifies the table in-place and does not return a value.

Note:
This function requires the use of [LanceDB](https://lancedb.github.io/lancedb/), which is the Python library for the LanceDB project.
This function writes Lance datasets via the Lance Python package.
To ensure that this is installed with Daft, you may install: `pip install daft[lance]`

Examples:
Merge new columns into a local LanceDB table:
Merge new columns into a local Lance dataset:
>>> import daft
>>> # Read the existing table with row addresses
>>> df = daft.read_lance(
... "s3://my-lancedb-bucket/data/",
... "s3://my-lance-bucket/data/",
... default_scan_options={"with_row_address": True},
... include_fragment_id=True,
... )
>>> # Add new columns based on existing data
>>> df = df.with_column("doubled_c", df["c"] * 2)
>>> # Merge the new columns back to the table
>>> daft_lance.merge_columns_df(df, "s3://my-lancedb-bucket/data/")
>>> daft_lance.merge_columns_df(df, "s3://my-lance-bucket/data/")
"""
io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config
storage_options = storage_options or io_config_to_storage_options(io_config, uri)
Expand Down Expand Up @@ -385,7 +385,7 @@ def create_scalar_index(

Args:
uri: The URI of the Lance table (supports remote URLs to object stores such as `s3://` or `gs://`)
io_config: A custom IOConfig to use when accessing LanceDB data. Defaults to None.
io_config: A custom IOConfig to use when accessing Lance data. Defaults to None.
column: Column name to index
index_type: Type of index to build.
For distributed execution this supports "INVERTED", "FTS", and "BTREE".
Expand Down Expand Up @@ -424,7 +424,7 @@ def create_scalar_index(
ImportError: If lance package is not available

Note:
This function requires the use of [LanceDB](https://lancedb.github.io/lancedb/), which is the Python library for the LanceDB project.
This function writes Lance datasets via the Lance Python package.
To ensure that this is installed with Daft, you may install: `pip install daft[lance]`

Examples:
Expand Down Expand Up @@ -510,7 +510,7 @@ def compact_files(

Args:
uri: The URI of the Lance table (supports remote URLs to object stores such as `s3://` or `gs://`)
io_config: A custom IOConfig to use when accessing LanceDB data. Defaults to None.
io_config: A custom IOConfig to use when accessing Lance data. Defaults to None.
storage_options: Extra options for storage connection.
version: If specified, load a specific version of the Lance dataset.
asof: If specified, find the latest version created on or earlier than the given argument value.
Expand Down
71 changes: 59 additions & 12 deletions daft_lance/lance_scan.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Iterator
from typing import Any

Expand All @@ -23,7 +24,7 @@


# TODO support fts and fast_search
def _lancedb_table_factory_function(
def _lance_table_factory_function(
ds_uri: str,
open_kwargs: dict[Any, Any] | None = None,
fragment_ids: list[int] | None = None,
Expand Down Expand Up @@ -110,13 +111,13 @@ def _batches() -> Iterator[PyRecordBatch]:
return _iter_batches()


def _lancedb_count_result_function(
def _lance_count_result_function(
ds_uri: str,
open_kwargs: dict[Any, Any] | None,
required_column: str,
filter: pa.compute.Expression | None = None,
) -> Iterator[PyRecordBatch]:
"""Use LanceDB's API to count rows and return a record batch with the count result."""
"""Use Lance's API to count rows and return a record batch with the count result."""
ds = lance.dataset(ds_uri, **(open_kwargs or {}))
logger.debug("Using metadata for counting all rows")
count = ds.count_rows(filter=filter)
Expand All @@ -128,7 +129,27 @@ def _lancedb_count_result_function(
yield result_batch._recordbatch


class LanceDBScanOperator(ScanOperator, SupportsPushdownFilters):
def _lancedb_table_factory_function(*args: Any, **kwargs: Any) -> Iterator[PyRecordBatch]:
warnings.warn(
"_lancedb_table_factory_function is deprecated and will be removed in a future release. "
"Use _lance_table_factory_function instead.",
DeprecationWarning,
stacklevel=2,
)
return _lance_table_factory_function(*args, **kwargs)


def _lancedb_count_result_function(*args: Any, **kwargs: Any) -> Iterator[PyRecordBatch]:
warnings.warn(
"_lancedb_count_result_function is deprecated and will be removed in a future release. "
"Use _lance_count_result_function instead.",
DeprecationWarning,
stacklevel=2,
)
return _lance_count_result_function(*args, **kwargs)


class LanceScanOperator(ScanOperator, SupportsPushdownFilters):
def __init__(
self,
ds: lance.LanceDataset,
Expand All @@ -147,10 +168,10 @@ def __init__(
self._schema = convert_lance_schema(base)

def name(self) -> str:
return "LanceDBScanOperator"
return "LanceScanOperator"

def display_name(self) -> str:
return f"LanceDBScanOperator({self._ds.uri})"
return f"LanceScanOperator({self._ds.uri})"

def schema(self) -> Schema:
return self._schema
Expand Down Expand Up @@ -254,8 +275,8 @@ def _create_count_rows_scan_task(self, pushdowns: PyPushdowns) -> Iterator[ScanT
new_schema = Schema.from_pyarrow_schema(pa.schema([pa.field(fields[0], pa.uint64())]))
open_kwargs = getattr(self._ds, "_lance_open_kwargs", None)
yield ScanTask.python_factory_func_scan_task(
module=_lancedb_count_result_function.__module__,
func_name=_lancedb_count_result_function.__name__,
module=_lance_count_result_function.__module__,
func_name=_lance_count_result_function.__name__,
func_args=(self._ds.uri, open_kwargs, fields[0], self._combine_filters_to_arrow()),
schema=new_schema._schema,
num_rows=1,
Expand Down Expand Up @@ -293,8 +314,8 @@ def _create_scan_tasks_with_limit_and_no_filters(

task_schema = self._schema
yield ScanTask.python_factory_func_scan_task(
module=_lancedb_table_factory_function.__module__,
func_name=_lancedb_table_factory_function.__name__,
module=_lance_table_factory_function.__module__,
func_name=_lance_table_factory_function.__name__,
func_args=(
self._ds.uri,
open_kwargs,
Expand Down Expand Up @@ -328,8 +349,8 @@ def _python_factory_func_scan_task(
size_bytes: int | None = None,
) -> ScanTask:
return ScanTask.python_factory_func_scan_task(
module=_lancedb_table_factory_function.__module__,
func_name=_lancedb_table_factory_function.__name__,
module=_lance_table_factory_function.__module__,
func_name=_lance_table_factory_function.__name__,
func_args=(
self._ds.uri,
open_kwargs,
Expand Down Expand Up @@ -486,3 +507,29 @@ def _estimate_size_bytes(fragment: lance.LanceFragment) -> int:
return 0

return sum(file.file_size_bytes for file in fragment.metadata.files if file.file_size_bytes is not None)


class LanceDBScanOperator(LanceScanOperator):
def __init__(
self,
ds: lance.LanceDataset,
fragment_group_size: int | None = None,
include_fragment_id: bool | None = False,
):
warnings.warn(
"LanceDBScanOperator is deprecated and will be removed in a future release. "
"Use LanceScanOperator instead.",
DeprecationWarning,
stacklevel=2,
)
super().__init__(
ds,
fragment_group_size=fragment_group_size,
include_fragment_id=include_fragment_id,
)

def name(self) -> str:
return "LanceDBScanOperator"

def display_name(self) -> str:
return f"LanceDBScanOperator({self._ds.uri})"
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,12 @@
from daft import col
from daft.daft import CountMode
from daft.recordbatch import RecordBatch
from daft_lance.lance_scan import LanceDBScanOperator, _lancedb_count_result_function
from daft_lance.lance_scan import (
LanceDBScanOperator,
LanceScanOperator,
_lance_count_result_function,
_lancedb_count_result_function,
)


class TestLanceCountResultFunction:
Expand All @@ -30,29 +35,41 @@ def test_dataset_path(self, tmp_path_factory):
lance.write_dataset(pa.Table.from_pydict(test_data), tmp_dir)
yield str(tmp_dir)

def test_lancedb_count_no_filters_direct_call(self, test_dataset_path):
def test_lance_count_no_filters_direct_call(self, test_dataset_path):
"""Test that no filters list is handled correctly."""
ds = lance.dataset(test_dataset_path)
result_generator = _lancedb_count_result_function(ds.uri, None, "count")
result_generator = _lance_count_result_function(ds.uri, None, "count")
result_batch = next(result_generator)
record_batch = RecordBatch._from_pyrecordbatch(result_batch)
result_dict = record_batch.to_pydict()
assert result_dict["count"][0] == 6

def test_lancedb_count_with_filters_path(self, test_dataset_path):
def test_lance_count_with_filters_path(self, test_dataset_path):
"""Test that filters list is handled correctly."""
ds = lance.dataset(test_dataset_path)
filter_expr = pc.greater(pc.field("age"), pc.scalar(30))
result_generator = _lancedb_count_result_function(ds.uri, None, "count", filter=filter_expr)
result_generator = _lance_count_result_function(ds.uri, None, "count", filter=filter_expr)
result_batch = next(result_generator)
record_batch = RecordBatch._from_pyrecordbatch(result_batch)
result_dict = record_batch.to_pydict()
assert result_dict["count"][0] == 4

def test_deprecated_lancedb_count_result_function_alias(self, test_dataset_path):
"""Test that the old count helper name remains available."""
ds = lance.dataset(test_dataset_path)

with pytest.deprecated_call(match="_lancedb_count_result_function is deprecated"):
result_generator = _lancedb_count_result_function(ds.uri, None, "count")

result_batch = next(result_generator)
record_batch = RecordBatch._from_pyrecordbatch(result_batch)
result_dict = record_batch.to_pydict()
assert result_dict["count"][0] == 6

def test_unsupported_count_mode_fallback(self, test_dataset_path):
"""Test that unsupported count mode falls back to regular scan."""
ds = lance.dataset(test_dataset_path)
scan_op = LanceDBScanOperator(ds)
scan_op = LanceScanOperator(ds)

with patch.object(scan_op, "supported_count_modes", return_value=[CountMode.All]):
with patch("daft_lance.lance_scan.logger") as mock_logger:
Expand All @@ -77,13 +94,24 @@ def test_unsupported_count_mode_fallback(self, test_dataset_path):
def test_empty_filters_list_handling(self, test_dataset_path):
"""Test that empty filters list is handled correctly."""
ds = lance.dataset(test_dataset_path)
scan_op = LanceDBScanOperator(ds)
scan_op = LanceScanOperator(ds)
pushed, remaining = scan_op.push_filters([])

assert len(pushed) == 0
assert len(remaining) == 0
assert scan_op._pushed_filters is None

def test_deprecated_lancedb_scan_operator_alias(self, test_dataset_path):
"""Test that the old public scan operator name remains available."""
ds = lance.dataset(test_dataset_path)

with pytest.deprecated_call(match="LanceDBScanOperator is deprecated"):
scan_op = LanceDBScanOperator(ds)

assert isinstance(scan_op, LanceScanOperator)
assert scan_op.name() == "LanceDBScanOperator"
assert scan_op.display_name() == f"LanceDBScanOperator({ds.uri})"

def test_very_large_filter_expression(self, test_dataset_path):
"""Test that very large filter expressions are handled correctly."""
df = daft.read_lance(test_dataset_path)
Expand Down
Loading