Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ env:
WEAVIATE_133: 1.33.11
WEAVIATE_134: 1.34.8
WEAVIATE_135: 1.35.2
WEAVIATE_136: 1.36.0-dev-c8f578d
WEAVIATE_136: 1.36.0-dev-0bbf31a

jobs:
lint-and-format:
Expand Down
4 changes: 2 additions & 2 deletions integration/test_batch_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ def test_add_ten_thousand_data_objects(
"""Test adding ten thousand data objects."""
client, name = client_factory()
if (
request.node.callspec.id == "test_add_ten_thousand_data_objects_experimental"
request.node.callspec.id == "test_add_ten_thousand_data_objects_stream"
and client._connection._weaviate_version.is_lower_than(1, 36, 0)
):
pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")
Expand Down Expand Up @@ -641,7 +641,7 @@ def test_add_one_object_and_a_self_reference(
"""Test adding one object and a self reference."""
client, name = client_factory()
if (
request.node.callspec.id == "test_add_one_object_and_a_self_reference_experimental"
request.node.callspec.id == "test_add_one_object_and_a_self_reference_stream"
and client._connection._weaviate_version.is_lower_than(1, 36, 0)
):
pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")
Expand Down
45 changes: 44 additions & 1 deletion integration/test_collection_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def test_non_existant_collection(collection_factory_get: CollectionFactoryGet) -


@pytest.mark.asyncio
async def test_add_one_hundred_thousand_objects_async_collection(
async def test_batch_one_hundred_thousand_objects_async_collection(
batch_collection_async: BatchCollectionAsync,
) -> None:
"""Test adding one hundred thousand data objects."""
Expand All @@ -295,3 +295,46 @@ async def test_add_one_hundred_thousand_objects_async_collection(
assert await col.length() == nr_objects
assert col.batch.results.objs.has_errors is False
assert len(col.batch.failed_objects) == 0, [obj.message for obj in col.batch.failed_objects]


@pytest.mark.asyncio
async def test_ingest_one_hundred_thousand_data_objects_async(
    batch_collection_async: BatchCollectionAsync,
) -> None:
    """Test ingesting one hundred thousand data objects via the async ingest API."""
    col = await batch_collection_async()
    if col._connection._weaviate_version.is_lower_than(1, 36, 0):
        pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")
    nr_objects = 100000
    import time

    # perf_counter is monotonic and intended for elapsed-time measurement,
    # unlike time.time which can jump with wall-clock adjustments.
    start = time.perf_counter()
    results = await col.data.ingest({"name": "test" + str(i)} for i in range(nr_objects))
    elapsed = time.perf_counter() - start
    print(f"Time taken to add {nr_objects} objects: {elapsed} seconds")
    # Check the error state first so a failure surfaces the server-side messages.
    assert results.has_errors is False
    assert len(results.errors) == 0, [obj.message for obj in results.errors.values()]
    assert len(results.all_responses) == nr_objects
    assert len(results.uuids) == nr_objects
    assert await col.length() == nr_objects


def test_ingest_one_hundred_thousand_data_objects(
    batch_collection: BatchCollection,
) -> None:
    """Test ingesting one hundred thousand data objects via the sync ingest API."""
    col = batch_collection()
    if col._connection._weaviate_version.is_lower_than(1, 36, 0):
        pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")
    nr_objects = 100000
    import time

    # perf_counter is monotonic and intended for elapsed-time measurement,
    # unlike time.time which can jump with wall-clock adjustments.
    start = time.perf_counter()
    results = col.data.ingest({"name": "test" + str(i)} for i in range(nr_objects))
    elapsed = time.perf_counter() - start
    print(f"Time taken to add {nr_objects} objects: {elapsed} seconds")
    # Check the error state first so a failure surfaces the server-side messages.
    assert results.has_errors is False
    assert len(results.errors) == 0, [obj.message for obj in results.errors.values()]
    assert len(results.all_responses) == nr_objects
    assert len(results.uuids) == nr_objects
    assert len(col) == nr_objects
4 changes: 2 additions & 2 deletions integration/test_rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,8 +742,8 @@ def test_server_side_batching_with_auth() -> None:
with connect_to_local(
port=RBAC_PORTS[0], grpc_port=RBAC_PORTS[1], auth_credentials=RBAC_AUTH_CREDS
) as client:
if client._connection._weaviate_version.is_lower_than(1, 34, 0):
pytest.skip("Server-side batching not supported in Weaviate < 1.34.0")
if client._connection._weaviate_version.is_lower_than(1, 36, 0):
pytest.skip("Server-side batching not supported in Weaviate < 1.36.0")
collection = client.collections.create(collection_name)
with client.batch.stream() as batch:
batch.add_object(collection_name)
Expand Down
3 changes: 3 additions & 0 deletions weaviate/collections/batch/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,9 @@ def number_errors(self) -> int:
def _start(self):
pass

def _wait(self) -> None:
    # No-op in this base implementation (same pattern as _start above);
    # presumably overridden by subclasses that have in-flight work to
    # wait for — TODO confirm against the streaming implementations.
    pass

def _shutdown(self) -> None:
"""Shutdown the current batch and wait for all requests to be finished."""
self.flush()
Expand Down
10 changes: 4 additions & 6 deletions weaviate/collections/batch/batch_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
_ClusterBatch,
_ClusterBatchAsync,
_DynamicBatching,
_ServerSideBatching,
)
from weaviate.collections.batch.sync import _BatchBaseSync
from weaviate.collections.classes.batch import (
Expand Down Expand Up @@ -140,8 +139,6 @@ def __init__(
self._connection = connection
self._consistency_level = consistency_level
self._current_batch: Optional[_BatchBaseAsync] = None
# config options
self._batch_mode: _BatchMode = _ServerSideBatching(1)

self._batch_data = _BatchDataWrapper()
self._cluster = _ClusterBatchAsync(connection)
Expand Down Expand Up @@ -371,7 +368,7 @@ async def add_reference(
"""
...

async def flush(self) -> None:
def flush(self) -> None:
"""Flush the current batch.

This will send all the objects and references in the current batch to Weaviate.
Expand Down Expand Up @@ -505,19 +502,20 @@ def number_errors(self) -> int:
Q = TypeVar("Q", bound=Union[BatchClientProtocolAsync, BatchCollectionProtocolAsync[Properties]])


class _ContextManagerWrapper(Generic[T, P]):
class _ContextManagerSync(Generic[T, P]):
    """Synchronous context manager driving a batch's start/shutdown lifecycle."""

    def __init__(self, current_batch: T):
        # The batch instance whose lifecycle this context manager controls.
        self.__batch: T = current_batch

    def __enter__(self) -> P:
        self.__batch._start()
        return self.__batch  # pyright: ignore[reportReturnType]

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.__batch._shutdown()
        self.__batch._wait()


class _ContextManagerWrapperAsync(Generic[Q]):
class _ContextManagerAsync(Generic[Q]):
def __init__(self, current_batch: _BatchBaseAsync):
self.__current_batch = current_batch

Expand Down
20 changes: 10 additions & 10 deletions weaviate/collections/batch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
_BatchMode,
_BatchWrapper,
_BatchWrapperAsync,
_ContextManagerWrapper,
_ContextManagerWrapperAsync,
_ContextManagerAsync,
_ContextManagerSync,
)
from weaviate.collections.batch.sync import _BatchBaseSync
from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers
Expand Down Expand Up @@ -146,10 +146,10 @@ async def add_reference(
BatchClient = _BatchClient
BatchClientSync = _BatchClientSync
BatchClientAsync = _BatchClientAsync
ClientBatchingContextManager = _ContextManagerWrapper[
ClientBatchingContextManager = _ContextManagerSync[
Union[BatchClient, BatchClientSync], BatchClientProtocol
]
AsyncClientBatchingContextManager = _ContextManagerWrapperAsync[BatchClientProtocolAsync]
ClientBatchingContextManagerAsync = _ContextManagerAsync[BatchClientProtocolAsync]


class _BatchClientWrapper(_BatchWrapper):
Expand Down Expand Up @@ -196,7 +196,7 @@ def __create_batch_and_reset(

self._batch_data = _BatchDataWrapper() # clear old data

return _ContextManagerWrapper(
return _ContextManagerSync(
batch_client(
connection=self._connection,
consistency_level=self._consistency_level,
Expand Down Expand Up @@ -310,7 +310,7 @@ def __init__(

def __create_batch_and_reset(self):
self._batch_data = _BatchDataWrapper() # clear old data
return _ContextManagerWrapperAsync(
return _ContextManagerAsync(
BatchClientAsync(
connection=self._connection,
consistency_level=self._consistency_level,
Expand All @@ -328,15 +328,15 @@ def experimental(
*,
concurrency: Optional[int] = None,
consistency_level: Optional[ConsistencyLevel] = None,
) -> AsyncClientBatchingContextManager:
) -> ClientBatchingContextManagerAsync:
return self.stream(concurrency=concurrency, consistency_level=consistency_level)

def stream(
self,
*,
concurrency: Optional[int] = None,
consistency_level: Optional[ConsistencyLevel] = None,
) -> AsyncClientBatchingContextManager:
) -> ClientBatchingContextManagerAsync:
"""Configure the batching context manager to use batch streaming.

When you exit the context manager, the final batch will be sent automatically.
Expand All @@ -345,9 +345,9 @@ def stream(
concurrency: The number of concurrent streams to use when sending batches. If not provided, the default will be one.
consistency_level: The consistency level to be used when inserting data. If not provided, the default value is `None`.
"""
if self._connection._weaviate_version.is_lower_than(1, 34, 0):
if self._connection._weaviate_version.is_lower_than(1, 36, 0):
raise WeaviateUnsupportedFeatureError(
"Server-side batching", str(self._connection._weaviate_version), "1.34.0"
"Server-side batching", str(self._connection._weaviate_version), "1.36.0"
)
self._batch_mode = _ServerSideBatching(
# concurrency=concurrency
Expand Down
26 changes: 13 additions & 13 deletions weaviate/collections/batch/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
BatchCollectionProtocolAsync,
_BatchWrapper,
_BatchWrapperAsync,
_ContextManagerWrapper,
_ContextManagerWrapperAsync,
_ContextManagerAsync,
_ContextManagerSync,
)
from weaviate.collections.batch.sync import _BatchBaseSync
from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers
Expand Down Expand Up @@ -88,14 +88,14 @@ def add_reference(
class _BatchCollectionSync(Generic[Properties], _BatchBaseSync):
def __init__(
self,
executor: ThreadPoolExecutor,
connection: ConnectionSync,
consistency_level: Optional[ConsistencyLevel],
results: _BatchDataWrapper,
batch_mode: _BatchMode,
name: str,
tenant: Optional[str],
vectorizer_batching: bool,
executor: Optional[ThreadPoolExecutor] = None,
batch_mode: Optional[_BatchMode] = None,
vectorizer_batching: bool = False,
) -> None:
super().__init__(
connection=connection,
Expand Down Expand Up @@ -184,11 +184,11 @@ async def add_reference(
BatchCollection = _BatchCollection
BatchCollectionSync = _BatchCollectionSync
BatchCollectionAsync = _BatchCollectionAsync
CollectionBatchingContextManager = _ContextManagerWrapper[
CollectionBatchingContextManager = _ContextManagerSync[
Union[BatchCollection[Properties], BatchCollectionSync[Properties]],
BatchCollectionProtocol[Properties],
]
CollectionBatchingContextManagerAsync = _ContextManagerWrapperAsync[
CollectionBatchingContextManagerAsync = _ContextManagerAsync[
BatchCollectionProtocolAsync[Properties]
]

Expand Down Expand Up @@ -239,7 +239,7 @@ def __create_batch_and_reset(
self._vectorizer_batching = False

self._batch_data = _BatchDataWrapper() # clear old data
return _ContextManagerWrapper(
return _ContextManagerSync(
batch_client(
connection=self._connection,
consistency_level=self._consistency_level,
Expand Down Expand Up @@ -311,9 +311,9 @@ def stream(
concurrency: The number of concurrent requests when sending batches. This controls the number of concurrent requests
made to Weaviate. If not provided, the default value is 1.
"""
if self._connection._weaviate_version.is_lower_than(1, 34, 0):
if self._connection._weaviate_version.is_lower_than(1, 36, 0):
raise WeaviateUnsupportedFeatureError(
"Server-side batching", str(self._connection._weaviate_version), "1.34.0"
"Server-side batching", str(self._connection._weaviate_version), "1.36.0"
)
self._batch_mode = _ServerSideBatching(
# concurrency=concurrency
Expand All @@ -338,7 +338,7 @@ def __init__(

def __create_batch_and_reset(self):
self._batch_data = _BatchDataWrapper() # clear old data
return _ContextManagerWrapperAsync(
return _ContextManagerAsync(
BatchCollectionAsync(
connection=self._connection,
consistency_level=self._consistency_level,
Expand Down Expand Up @@ -371,9 +371,9 @@ def stream(
concurrency: The number of concurrent requests when sending batches. This controls the number of concurrent requests
made to Weaviate. If not provided, the default value is 1.
"""
if self._connection._weaviate_version.is_lower_than(1, 34, 0):
if self._connection._weaviate_version.is_lower_than(1, 36, 0):
raise WeaviateUnsupportedFeatureError(
"Server-side batching", str(self._connection._weaviate_version), "1.34.0"
"Server-side batching", str(self._connection._weaviate_version), "1.36.0"
)
self._batch_mode = _ServerSideBatching(
# concurrency=concurrency
Expand Down
16 changes: 4 additions & 12 deletions weaviate/collections/batch/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
_BatchMode,
_BgThreads,
_ClusterBatch,
_ServerSideBatching,
)
from weaviate.collections.batch.grpc_batch import _BatchGRPC
from weaviate.collections.classes.batch import (
Expand Down Expand Up @@ -54,9 +53,9 @@ def __init__(
connection: ConnectionSync,
consistency_level: Optional[ConsistencyLevel],
results: _BatchDataWrapper,
batch_mode: _BatchMode,
executor: ThreadPoolExecutor,
vectorizer_batching: bool,
batch_mode: Optional[_BatchMode] = None,
executor: Optional[ThreadPoolExecutor] = None,
vectorizer_batching: bool = False,
objects: Optional[ObjectsBatchRequest[BatchObject]] = None,
references: Optional[ReferencesBatchRequest[BatchReference]] = None,
) -> None:
Expand Down Expand Up @@ -108,8 +107,6 @@ def __init__(
self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1)
self.__stop = False

self.__batch_mode = batch_mode

@property
def number_errors(self) -> int:
"""Return the number of errors in the batch."""
Expand All @@ -123,12 +120,7 @@ def __all_threads_alive(self) -> bool:
)

def _start(self) -> None:
assert isinstance(self.__batch_mode, _ServerSideBatching), (
"Only server-side batching is supported in this mode"
)
self.__bg_threads = [
self.__start_bg_threads() for _ in range(self.__batch_mode.concurrency)
]
self.__bg_threads = [self.__start_bg_threads() for _ in range(1)]
logger.info(
f"Provisioned {len(self.__bg_threads)} stream(s) to the server for batch processing"
)
Expand Down
7 changes: 4 additions & 3 deletions weaviate/collections/data/async_.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import uuid as uuid_package
from typing import Generic, List, Literal, Optional, Sequence, Union, overload
from typing import Generic, Iterable, List, Literal, Optional, Sequence, Union, overload

from weaviate.collections.batch.collection import _BatchCollectionWrapper
from weaviate.collections.batch.grpc_batch import _BatchGRPC
from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC
from weaviate.collections.batch.rest import _BatchREST
Expand Down Expand Up @@ -30,7 +29,6 @@ class _DataCollectionAsync(
__batch_delete: _BatchDeleteGRPC
__batch_grpc: _BatchGRPC
__batch_rest: _BatchREST
__batch: _BatchCollectionWrapper[Properties]

async def insert(
self,
Expand Down Expand Up @@ -81,3 +79,6 @@ class _DataCollectionAsync(
async def delete_many(
self, where: _Filters, *, verbose: bool = False, dry_run: bool = False
) -> Union[DeleteManyReturn[List[DeleteManyObject]], DeleteManyReturn[None]]: ...
async def ingest(
self, objs: Iterable[Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]]
) -> BatchObjectReturn: ...
Loading
Loading