@@ -20,8 +20,8 @@
 from google.cloud.bigquery import enums
 from google.cloud.bigquery_storage_v1 import types as gapic_types
 from google.cloud.bigquery_storage_v1.writer import AppendRowsStream
-import pandas as pd
 
+import pandas as pd
 import pyarrow as pa
 
 TABLE_LENGTH = 100_000
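
The import regrouped above matters because everything later in this diff serializes pyarrow record batches. As a point of reference only, here is a minimal sketch of building an in-memory Arrow table; the column names, types, and the make_demo_table helper are illustrative assumptions, not the sample's actual generate_pyarrow_table schema.

import pyarrow as pa

# Hypothetical columns; the sample defines its own schema in generate_pyarrow_table.
demo_schema = pa.schema(
    [
        pa.field("id", pa.int64()),
        pa.field("name", pa.string()),
    ]
)


def make_demo_table(num_rows):
    """Build a small in-memory Arrow table for illustration only."""
    return pa.Table.from_pydict(
        {
            "id": list(range(num_rows)),
            "name": [f"row-{i}" for i in range(num_rows)],
        },
        schema=demo_schema,
    )
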
@@ -100,7 +100,10 @@ def make_table(project_id, dataset_id, bq_client):


 def create_stream(bqstorage_write_client, table):
-    stream_name = f"projects/{table.project}/datasets/{table.dataset_id}/tables/{table.table_id}/_default"
+    stream_name = (
+        f"projects/{table.project}/datasets/{table.dataset_id}/"
+        f"tables/{table.table_id}/_default"
+    )
     request_template = gapic_types.AppendRowsRequest()
     request_template.write_stream = stream_name
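
For orientation, a request template like the one above is typically combined with a serialized Arrow schema and handed to AppendRowsStream to open the connection. The sketch below assumes that wiring; open_default_stream and the arrow_schema parameter are illustrative names, and the writer_schema assignment mirrors the arrow_rows fields used later in this diff.

def open_default_stream(bqstorage_write_client, table, arrow_schema):
    """Sketch: open an AppendRowsStream on a table's _default stream."""
    stream_name = (
        f"projects/{table.project}/datasets/{table.dataset_id}/"
        f"tables/{table.table_id}/_default"
    )
    request_template = gapic_types.AppendRowsRequest()
    request_template.write_stream = stream_name
    # Assumption: the Arrow writer schema travels on the request template.
    request_template.arrow_rows.writer_schema.serialized_schema = (
        arrow_schema.serialize().to_pybytes()
    )
    return AppendRowsStream(bqstorage_write_client, request_template)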

@@ -160,18 +163,50 @@ def generate_pyarrow_table(num_rows=TABLE_LENGTH):


 def generate_write_requests(pyarrow_table):
-    # Determine max_chunksize of the record batches. Because max size of
-    # AppendRowsRequest is 10 MB, we need to split the table if it's too big.
-    # See: https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1#appendrowsrequest
-    max_request_bytes = 10 * 2**20  # 10 MB
-    chunk_num = int(pyarrow_table.nbytes / max_request_bytes) + 1
-    chunk_size = int(pyarrow_table.num_rows / chunk_num)
-
-    # Construct request(s).
-    for batch in pyarrow_table.to_batches(max_chunksize=chunk_size):
+    # Maximum size for a single AppendRowsRequest is 10 MB.
+    # To be safe, we'll aim for a soft limit of 7 MB.
+    max_request_bytes = 7 * 1024 * 1024  # 7 MB
+
+    def _create_request(batches):
+        """Helper to create an AppendRowsRequest from a list of batches."""
+        combined_table = pa.Table.from_batches(batches)
         request = gapic_types.AppendRowsRequest()
-        request.arrow_rows.rows.serialized_record_batch = batch.serialize().to_pybytes()
-        yield request
+        request.arrow_rows.rows.serialized_record_batch = (
+            combined_table.combine_chunks().to_batches()[0].serialize().to_pybytes()
+        )
+        return request
+
+    batches_in_request = []
+    current_size = 0
+
+    # Split table into batches of one row.
+    for row_batch in pyarrow_table.to_batches(max_chunksize=1):
+        serialized_batch = row_batch.serialize().to_pybytes()
+        batch_size = len(serialized_batch)
+
+        if batch_size > max_request_bytes:
+            raise ValueError(
+                (
+                    "A single PyArrow batch of one row is larger than the "
+                    f"maximum request size (batch size: {batch_size} > "
+                    f"max request size: {max_request_bytes}). Cannot proceed."
+                )
+            )
+
+        if current_size + batch_size > max_request_bytes and batches_in_request:
+            # Combine collected batches and yield request
+            yield _create_request(batches_in_request)
+
+            # Reset for next request.
+            batches_in_request = []
+            current_size = 0
+
+        batches_in_request.append(row_batch)
+        current_size += batch_size
+
+    # Yield any remaining batches
+    if batches_in_request:
+        yield _create_request(batches_in_request)
 
 
 def verify_result(client, table, futures):
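
The new generate_write_requests packs one-row record batches greedily: rows accumulate until the next one would push the summed serialized size past the 7 MB soft limit, then the collected batches are combined into a single AppendRowsRequest. Below is a quick, hypothetical way to sanity-check that behavior on a small table; check_request_sizes is not part of the sample.

def check_request_sizes(pyarrow_table, soft_limit=7 * 1024 * 1024):
    """Sketch: report the serialized payload size of each generated request."""
    sizes = []
    for request in generate_write_requests(pyarrow_table):
        payload = request.arrow_rows.rows.serialized_record_batch
        # Combining one-row batches amortizes per-batch framing overhead, so
        # each payload is expected to stay at or under the soft limit.
        sizes.append(len(payload))
    return sizes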
@@ -188,7 +223,7 @@ def verify_result(client, table, futures):
     assert query_result.iloc[0, 0] >= TABLE_LENGTH
 
     # Verify that table was split into multiple requests.
-    assert len(futures) == 2
+    assert len(futures) == 21
 
 
 def main(project_id, dataset):
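
Taken together, the helpers touched in this diff are wired roughly as follows. This is a sketch rather than the sample's actual main(): it assumes create_stream returns the opened AppendRowsStream and that the futures returned by send() are awaited before verification.

from google.cloud import bigquery
from google.cloud import bigquery_storage_v1


def run_sample(project_id, dataset_id):
    """Sketch of the end-to-end flow, under the assumptions above."""
    bq_client = bigquery.Client(project=project_id)
    bqstorage_write_client = bigquery_storage_v1.BigQueryWriteClient()

    table = make_table(project_id, dataset_id, bq_client)
    append_rows_stream = create_stream(bqstorage_write_client, table)

    pyarrow_table = generate_pyarrow_table()
    futures = [
        append_rows_stream.send(request)
        for request in generate_write_requests(pyarrow_table)
    ]
    for future in futures:
        future.result()

    verify_result(bq_client, table, futures)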