53 changes: 53 additions & 0 deletions paimon-python/pypaimon/tests/ray_data_test.py
@@ -130,6 +130,59 @@ def test_basic_ray_data_read(self):
            ['Alice', 'Bob', 'Charlie', 'David', 'Eve'],
            "Name column should match"
        )

    def test_basic_ray_data_write(self):
        """Test writing a Ray Dataset into a PyPaimon table."""
        # Create schema
        pa_schema = pa.schema([
            ('id', pa.int32()),
            ('name', pa.string()),
            ('value', pa.int64()),
        ])

        schema = Schema.from_pyarrow_schema(pa_schema)
        self.catalog.create_table('default.test_ray_basic_write', schema, False)
        table = self.catalog.get_table('default.test_ray_basic_write')

        # Write test data
        test_data = pa.Table.from_pydict({
            'id': [1, 2, 3, 4, 5],
            'name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve'],
            'value': [100, 200, 300, 400, 500],
        }, schema=pa_schema)

        from ray.data import from_arrow
        ds = from_arrow(test_data)
        write_builder = table.new_batch_write_builder()
        writer = write_builder.new_write()
        writer.write_raydata(ds, parallelism=2)

        # Read back through the Paimon table read API
        read_builder = table.new_read_builder()
        table_read = read_builder.new_read()
        table_scan = read_builder.new_scan()
        splits = table_scan.plan().splits()

        arrow_result = table_read.to_arrow(splits)

        # Verify the Arrow result
        self.assertIsNotNone(arrow_result, "Arrow result should not be None")
        self.assertEqual(arrow_result.num_rows, 5, "Should have 5 rows")

        # Test basic operations
        sample_data = arrow_result.take([0, 1, 2])
        self.assertEqual(sample_data.num_rows, 3, "Should have 3 sample rows")

        # Convert to pandas for verification
        df = arrow_result.to_pandas()
        self.assertEqual(len(df), 5, "DataFrame should have 5 rows")
        # Sort by id to ensure order-independent comparison
        df_sorted = df.sort_values(by='id').reset_index(drop=True)
        self.assertEqual(list(df_sorted['id']), [1, 2, 3, 4, 5], "ID column should match")
        self.assertEqual(
            list(df_sorted['name']),
            ['Alice', 'Bob', 'Charlie', 'David', 'Eve'],
            "Name column should match"
        )

    def test_ray_data_with_predicate(self):
        """Test Ray Data read with predicate filtering."""
85 changes: 85 additions & 0 deletions paimon-python/pypaimon/write/ray_datasink.py
@@ -0,0 +1,85 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
Module to reawrited a Paimon table from a Ray Dataset, by using the Ray Datasink API.
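
Example (a minimal sketch; ``table`` is assumed to be a Paimon table obtained
from a catalog and ``ds`` a Ray Dataset, both created elsewhere):

    from pypaimon.write.ray_datasink import PaimonDatasink

    sink = PaimonDatasink(table, overwrite=False)
    ds.write_datasink(sink, concurrency=2)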
"""

from typing import Iterable

import pyarrow as pa
from ray.data._internal.execution.interfaces import TaskContext
from ray.data.block import Block, BlockAccessor
from ray.data.datasource.datasink import Datasink, WriteResult, WriteReturnType

from pypaimon.table.table import Table
from pypaimon.write.write_builder import WriteBuilder


class PaimonDatasink(Datasink):
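    """Ray Datasink that writes Ray Dataset blocks into a Paimon table.

    Each write task converts its blocks to Arrow and writes them through a
    Paimon ``TableWrite``; the per-task commit messages are committed once in
    ``on_write_complete``.
    """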

    def __init__(self, table: Table, overwrite=False):
        self.table = table
        self.overwrite = overwrite

    def on_write_start(self) -> None:
        """Callback for when a write job starts.

        Creates the batch write builder that the write tasks and the final
        commit will share.
        """
        self.writer_builder: WriteBuilder = self.table.new_batch_write_builder()
        if self.overwrite:
            self.writer_builder = self.writer_builder.overwrite()

    def write(
        self,
        blocks: Iterable[Block],
        ctx: TaskContext,
    ) -> WriteReturnType:
        """Write blocks. This is used by a single write task.

        Args:
            blocks: Generator of data blocks.
            ctx: ``TaskContext`` for the write task.

        Returns:
            Result of this write task. When the entire write operator finishes,
            all returned values are passed as ``WriteResult.write_returns``
            to ``Datasink.on_write_complete``.
        """
        table_write = self.writer_builder.new_write()
        for block in blocks:
            block_arrow: pa.Table = BlockAccessor.for_block(block).to_arrow()
            table_write.write_arrow(block_arrow)
        commit_messages = table_write.prepare_commit()
        table_write.close()
        return commit_messages

    def on_write_complete(self, write_result: WriteResult[WriteReturnType]):
        """Callback for when a write job completes.

        This can be used to ``commit`` a write output. This method must
        succeed prior to ``write_datasink()`` returning to the user. If this
        method fails, then ``on_write_failed()`` is called.

        Args:
            write_result: Aggregated result of the write operator, containing
                write results and stats.
        """
        table_commit = self.writer_builder.new_commit()
        commit_messages = [
            message
            for task_messages in write_result.write_returns
            for message in task_messages
        ]
        table_commit.commit(commit_messages)
        table_commit.close()

5 changes: 5 additions & 0 deletions paimon-python/pypaimon/write/table_write.py
@@ -68,6 +68,11 @@ def with_write_type(self, write_cols: List[str]):
        self.file_store_write.write_cols = write_cols
        return self

    def write_raydata(self, dataset, overwrite=False, parallelism=1):
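        """Write a Ray Dataset into this table through the PaimonDatasink.

        Args:
            dataset: The Ray Dataset to write.
            overwrite: Whether to overwrite existing data instead of appending.
            parallelism: Number of concurrent Ray write tasks.
        """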
        from pypaimon.write.ray_datasink import PaimonDatasink
        datasink = PaimonDatasink(self.table, overwrite=overwrite)
        dataset.write_datasink(datasink, concurrency=parallelism)

    def close(self):
        self.file_store_write.close()
