diff --git a/integration/test_batch_v4.py b/integration/test_batch_v4.py index 1a635298c..8cb175884 100644 --- a/integration/test_batch_v4.py +++ b/integration/test_batch_v4.py @@ -4,6 +4,7 @@ from typing import Callable, Generator, List, Optional, Protocol, Tuple import pytest +import pytest_asyncio from _pytest.fixtures import SubRequest import weaviate @@ -119,6 +120,53 @@ def _factory( client_fixture.close() +class AsyncClientFactory(Protocol): + """Typing for fixture.""" + + async def __call__( + self, name: str = "", ports: Tuple[int, int] = (8080, 50051), multi_tenant: bool = False + ) -> Tuple[weaviate.WeaviateAsyncClient, str]: + """Typing for fixture.""" + ... + + +@pytest_asyncio.fixture +async def async_client_factory(request: SubRequest): + name_fixtures: List[str] = [] + client_fixture: Optional[weaviate.WeaviateAsyncClient] = None + + async def _factory( + name: str = "", ports: Tuple[int, int] = (8080, 50051), multi_tenant: bool = False + ): + nonlocal client_fixture, name_fixtures # noqa: F824 + name_fixture = _sanitize_collection_name(request.node.name) + name + name_fixtures.append(name_fixture) + if client_fixture is None: + client_fixture = weaviate.use_async_with_local(grpc_port=ports[1], port=ports[0]) + await client_fixture.connect() + + if await client_fixture.collections.exists(name_fixture): + await client_fixture.collections.delete(name_fixture) + + await client_fixture.collections.create( + name=name_fixture, + properties=[ + Property(name="name", data_type=DataType.TEXT), + Property(name="age", data_type=DataType.INT), + ], + references=[ReferenceProperty(name="test", target_collection=name_fixture)], + multi_tenancy_config=Configure.multi_tenancy(multi_tenant), + vectorizer_config=Configure.Vectorizer.none(), + ) + return client_fixture, name_fixture + + try: + yield _factory + finally: + if client_fixture is not None: + await client_fixture.close() + + def test_add_objects_in_multiple_batches(client_factory: ClientFactory) -> None: client, name = client_factory() with client.batch.rate_limit(50) as batch: @@ -367,13 +415,13 @@ def test_add_ref_batch_with_tenant(client_factory: ClientFactory) -> None: [ lambda client: client.batch.dynamic(), lambda client: client.batch.fixed_size(), - # lambda client: client.batch.rate_limit(9999), + lambda client: client.batch.rate_limit(9999), lambda client: client.batch.experimental(concurrency=1), ], ids=[ "test_add_ten_thousand_data_objects_dynamic", "test_add_ten_thousand_data_objects_fixed_size", - # "test_add_ten_thousand_data_objects_rate_limit", + "test_add_ten_thousand_data_objects_rate_limit", "test_add_ten_thousand_data_objects_experimental", ], ) @@ -767,3 +815,34 @@ def test_references_with_to_uuids(client_factory: ClientFactory) -> None: assert len(client.batch.failed_references) == 0, client.batch.failed_references client.collections.delete(["target", "source"]) + + +@pytest.mark.asyncio +async def test_add_one_hundred_thousand_objects_async_client( + async_client_factory: AsyncClientFactory, +) -> None: + """Test adding one hundred thousand data objects.""" + client, name = await async_client_factory() + if client._connection._weaviate_version.is_lower_than(1, 36, 0): + pytest.skip("Server-side batching not supported in Weaviate < 1.36.0") + nr_objects = 100000 + import time + + start = time.time() + async with client.batch.experimental(concurrency=1) as batch: + for i in range(nr_objects): + await batch.add_object( + collection=name, + properties={"name": "test" + str(i)}, + ) + end = time.time() + 
print(f"Time taken to add {nr_objects} objects: {end - start} seconds") + assert len(client.batch.results.objs.errors) == 0 + assert len(client.batch.results.objs.all_responses) == nr_objects + assert len(client.batch.results.objs.uuids) == nr_objects + assert await client.collections.use(name).length() == nr_objects + assert client.batch.results.objs.has_errors is False + assert len(client.batch.failed_objects) == 0, [ + obj.message for obj in client.batch.failed_objects + ] + await client.collections.delete(name) diff --git a/integration/test_collection_batch.py b/integration/test_collection_batch.py index 72b3b7c53..ca0eb2116 100644 --- a/integration/test_collection_batch.py +++ b/integration/test_collection_batch.py @@ -1,10 +1,10 @@ import uuid from dataclasses import dataclass -from typing import Any, Generator, Optional, Protocol, Union +from typing import Any, Awaitable, Generator, Optional, Protocol, Union import pytest -from integration.conftest import CollectionFactory, CollectionFactoryGet +from integration.conftest import AsyncCollectionFactory, CollectionFactory, CollectionFactoryGet from weaviate.collections import Collection from weaviate.collections.classes.config import ( Configure, @@ -17,6 +17,8 @@ from weaviate.collections.classes.tenants import Tenant from weaviate.types import VECTORS +from weaviate.collections.collection.async_ import CollectionAsync + UUID = Union[str, uuid.UUID] @@ -55,11 +57,21 @@ def __call__(self, name: str = "", multi_tenancy: bool = False) -> Collection[An ... +class BatchCollectionAsync(Protocol): + """Typing for fixture.""" + + def __call__( + self, name: str = "", multi_tenancy: bool = False + ) -> Awaitable[CollectionAsync[Any, Any]]: + """Typing for fixture.""" + ... + + @pytest.fixture def batch_collection( collection_factory: CollectionFactory, ) -> Generator[BatchCollection, None, None]: - def _factory(name: str = "", multi_tenancy: bool = False) -> Collection[Any, Any]: + def _factory(name: str = "", multi_tenancy: bool = False): collection = collection_factory( name=name, vectorizer_config=Configure.Vectorizer.none(), @@ -78,6 +90,29 @@ def _factory(name: str = "", multi_tenancy: bool = False) -> Collection[Any, Any yield _factory +@pytest.fixture +def batch_collection_async( + async_collection_factory: AsyncCollectionFactory, +) -> Generator[BatchCollectionAsync, None, None]: + async def _factory(name: str = "", multi_tenancy: bool = False): + collection = await async_collection_factory( + name=name, + vectorizer_config=Configure.Vectorizer.none(), + properties=[ + Property(name="name", data_type=DataType.TEXT), + Property(name="age", data_type=DataType.INT), + ], + multi_tenancy_config=Configure.multi_tenancy(multi_tenancy), + ) + await collection.config.add_reference( + ReferenceProperty(name="test", target_collection=collection.name) + ) + + return collection + + yield _factory + + @pytest.mark.parametrize( "vector", [None, [1, 2, 3], MockNumpyTorch([1, 2, 3]), MockTensorFlow([1, 2, 3])], @@ -233,3 +268,30 @@ def test_non_existant_collection(collection_factory_get: CollectionFactoryGet) - # above should not throw - depending on the autoschema config this might create an error or # not, so we do not check for errors here + + +@pytest.mark.asyncio +async def test_add_one_hundred_thousand_objects_async_collection( + batch_collection_async: BatchCollectionAsync, +) -> None: + """Test adding one hundred thousand data objects.""" + col = await batch_collection_async() + if col._connection._weaviate_version.is_lower_than(1, 36, 0): + 
pytest.skip("Server-side batching not supported in Weaviate < 1.36.0") + nr_objects = 100000 + import time + + start = time.time() + async with col.batch.experimental() as batch: + for i in range(nr_objects): + await batch.add_object( + properties={"name": "test" + str(i)}, + ) + end = time.time() + print(f"Time taken to add {nr_objects} objects: {end - start} seconds") + assert len(col.batch.results.objs.errors) == 0 + assert len(col.batch.results.objs.all_responses) == nr_objects + assert len(col.batch.results.objs.uuids) == nr_objects + assert await col.length() == nr_objects + assert col.batch.results.objs.has_errors is False + assert len(col.batch.failed_objects) == 0, [obj.message for obj in col.batch.failed_objects] diff --git a/weaviate/client.py b/weaviate/client.py index 8cf856c51..d7f9080f4 100644 --- a/weaviate/client.py +++ b/weaviate/client.py @@ -10,7 +10,7 @@ from .auth import AuthCredentials from .backup import _Backup, _BackupAsync from .cluster import _Cluster, _ClusterAsync -from .collections.batch.client import _BatchClientWrapper +from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync from .collections.collections import _Collections, _CollectionsAsync from .config import AdditionalConfig from .connect import executor @@ -76,6 +76,7 @@ def __init__( ) self.alias = _AliasAsync(self._connection) self.backup = _BackupAsync(self._connection) + self.batch = _BatchClientWrapperAsync(self._connection) self.cluster = _ClusterAsync(self._connection) self.collections = _CollectionsAsync(self._connection) self.debug = _DebugAsync(self._connection) diff --git a/weaviate/client.pyi b/weaviate/client.pyi index 205a34b4e..9b32af15f 100644 --- a/weaviate/client.pyi +++ b/weaviate/client.pyi @@ -18,7 +18,7 @@ from weaviate.users.sync import _Users from .backup import _Backup, _BackupAsync from .cluster import _Cluster, _ClusterAsync -from .collections.batch.client import _BatchClientWrapper +from .collections.batch.client import _BatchClientWrapper, _BatchClientWrapperAsync from .debug import _Debug, _DebugAsync from .rbac import _Roles, _RolesAsync from .types import NUMBER @@ -29,6 +29,7 @@ class WeaviateAsyncClient(_WeaviateClientExecutor[ConnectionAsync]): _connection: ConnectionAsync alias: _AliasAsync backup: _BackupAsync + batch: _BatchClientWrapperAsync collections: _CollectionsAsync cluster: _ClusterAsync debug: _DebugAsync diff --git a/weaviate/collections/batch/async_.py b/weaviate/collections/batch/async_.py new file mode 100644 index 000000000..e8b1c3452 --- /dev/null +++ b/weaviate/collections/batch/async_.py @@ -0,0 +1,549 @@ +import asyncio +import time +import uuid as uuid_package +from typing import ( + Generator, + List, + Optional, + Set, + Union, +) + +from pydantic import ValidationError + +from weaviate.collections.batch.base import ( + GCP_STREAM_TIMEOUT, + ObjectsBatchRequest, + ReferencesBatchRequest, + _BatchDataWrapper, +) +from weaviate.collections.batch.grpc_batch import _BatchGRPC +from weaviate.collections.classes.batch import ( + BatchObject, + BatchObjectReturn, + BatchReference, + BatchReferenceReturn, + ErrorObject, + ErrorReference, + Shard, +) +from weaviate.collections.classes.config import ConsistencyLevel +from weaviate.collections.classes.internal import ( + ReferenceInput, + ReferenceInputs, + ReferenceToMulti, +) +from weaviate.collections.classes.types import WeaviateProperties +from weaviate.connect.v4 import ConnectionAsync +from weaviate.exceptions import ( + WeaviateBatchStreamError, + 
WeaviateBatchValidationError,
+    WeaviateGRPCUnavailableError,
+    WeaviateStartUpError,
+)
+from weaviate.logger import logger
+from weaviate.proto.v1 import batch_pb2
+from weaviate.types import UUID, VECTORS
+
+
+class _BgTasks:
+    def __init__(
+        self, send: asyncio.Task[None], recv: asyncio.Task[None], loop: asyncio.Task[None]
+    ) -> None:
+        self.send = send
+        self.recv = recv
+        self.loop = loop
+
+    def all_alive(self) -> bool:
+        return all([not self.send.done(), not self.recv.done(), not self.loop.done()])
+
+
+class _BatchBaseAsync:
+    def __init__(
+        self,
+        connection: ConnectionAsync,
+        consistency_level: Optional[ConsistencyLevel],
+        results: _BatchDataWrapper,
+        objects: Optional[ObjectsBatchRequest[BatchObject]] = None,
+        references: Optional[ReferencesBatchRequest[BatchReference]] = None,
+    ) -> None:
+        self.__batch_objects = objects or ObjectsBatchRequest[BatchObject]()
+        self.__batch_references = references or ReferencesBatchRequest[BatchReference]()
+
+        self.__connection = connection
+        self.__is_gcp_on_wcd = connection._connection_params.is_gcp_on_wcd()
+        self.__stream_start: Optional[float] = None
+        self.__is_renewing_stream = asyncio.Event()
+        self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM
+        self.__batch_size = 100
+
+        self.__batch_grpc = _BatchGRPC(
+            connection._weaviate_version, self.__consistency_level, connection._grpc_max_msg_size
+        )
+        self.__stream = self.__connection.grpc_batch_stream()
+
+        # lookup table for objects that are currently being processed - used to avoid sending references from objects that have not been added yet
+        self.__uuid_lookup: Set[str] = set()
+
+        # we do not want users to access the results directly as they are not thread-safe
+        self.__results_for_wrapper_backup = results
+        self.__results_for_wrapper = _BatchDataWrapper()
+
+        self.__objs_count = 0
+        self.__refs_count = 0
+
+        self.__is_oom = asyncio.Event()
+        self.__is_shutting_down = asyncio.Event()
+        self.__is_shutdown = asyncio.Event()
+
+        self.__objs_cache: dict[str, BatchObject] = {}
+        self.__refs_cache: dict[str, BatchReference] = {}
+
+        self.__inflight_objs: set[str] = set()
+        self.__inflight_refs: set[str] = set()
+
+        # maxsize=1 so that __send does not run faster than the generator for __recv,
+        # thereby using too much buffer in case of a server-side shutdown
+        self.__reqs: asyncio.Queue[Optional[batch_pb2.BatchStreamRequest]] = asyncio.Queue(
+            maxsize=1
+        )
+
+        self.__stop = False
+        self.__shutdown_send_task = asyncio.Event()
+        self.__bg_exception: Optional[Exception] = None
+        self.__bg_tasks: Optional[_BgTasks] = None
+
+    @property
+    def number_errors(self) -> int:
+        """Return the number of errors in the batch."""
+        return len(self.__results_for_wrapper.failed_objects) + len(
+            self.__results_for_wrapper.failed_references
+        )
+
+    def __all_tasks_alive(self) -> bool:
+        return self.__bg_tasks is not None and self.__bg_tasks.all_alive()
+
+    async def _wait(self):
+        assert self.__bg_tasks is not None
+        await asyncio.gather(
+            self.__bg_tasks.send,
+            self.__bg_tasks.recv,
+            self.__bg_tasks.loop,
+        )
+        # copy the results to the public results
+        self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results
+        self.__results_for_wrapper_backup.failed_objects = self.__results_for_wrapper.failed_objects
+        self.__results_for_wrapper_backup.failed_references = (
+            self.__results_for_wrapper.failed_references
+        )
+        self.__results_for_wrapper_backup.imported_shards = (
+            self.__results_for_wrapper.imported_shards
+        )
+
+    async def 
_start(self): + async def send_wrapper() -> None: + try: + await self.__send() + logger.warning("exited batch send thread") + except Exception as e: + logger.error(e) + self.__bg_exception = e + + async def loop_wrapper() -> None: + try: + await self.__loop() + logger.warning("exited batch loop thread") + except Exception as e: + logger.error(e) + self.__bg_exception = e + + async def recv_wrapper() -> None: + socket_hung_up = False + try: + await self.__recv() + logger.warning("exited batch receive thread") + except Exception as e: + if isinstance(e, WeaviateBatchStreamError) and ( + "Socket closed" in e.message or "context canceled" in e.message + ): + socket_hung_up = True + else: + logger.error(e) + self.__bg_exception = e + if socket_hung_up: + # this happens during ungraceful shutdown of the coordinator + # lets restart the stream and add the cached objects again + logger.warning("Stream closed unexpectedly, restarting...") + await self.__reconnect() + # server sets this whenever it restarts, gracefully or unexpectedly, so need to clear it now + self.__is_shutting_down.clear() + await self.__batch_objects.aprepend(list(self.__objs_cache.values())) + await self.__batch_references.aprepend(list(self.__refs_cache.values())) + # start a new stream with a newly reconnected channel + return await recv_wrapper() + + self.__bg_tasks = _BgTasks( + send=asyncio.create_task(send_wrapper()), + recv=asyncio.create_task(recv_wrapper()), + loop=asyncio.create_task(loop_wrapper()), + ) + + async def _shutdown(self) -> None: + self.__stop = True + + async def __loop(self) -> None: + refresh_time: float = 0.01 + while self.__bg_exception is None: + if len(self.__batch_objects) + len(self.__batch_references) > 0: + self._batch_send = True + start = time.time() + while (len_o := len(self.__batch_objects)) + ( + len_r := len(self.__batch_references) + ) < self.__batch_size: + # wait for more objects to be added up to the batch size + await asyncio.sleep(0.01) + if self.__shutdown_send_task.is_set(): + logger.warning("Tasks were shutdown, exiting batch send loop") + # shutdown was requested, exit early + await self.__reqs.put(None) + return + if time.time() - start >= 1 and ( + len_o == len(self.__batch_objects) or len_r == len(self.__batch_references) + ): + # no new objects were added in the last second, exit the loop + break + + objs = self.__batch_objects.pop_items(self.__batch_size) + refs = self.__batch_references.pop_items( + self.__batch_size - len(objs), + uuid_lookup=self.__uuid_lookup, + ) + + for req in self.__generate_stream_requests(objs, refs): + await self.__reqs.put(req) + elif self.__stop: + # we are done, send the sentinel into our queue to be consumed by the batch sender + await self.__reqs.put(None) # signal the end of the stream + logger.warning("Batching finished, sent stop signal to batch stream") + return + await asyncio.sleep(refresh_time) + + def __generate_stream_requests( + self, + objects: List[BatchObject], + references: List[BatchReference], + ) -> Generator[batch_pb2.BatchStreamRequest, None, None]: + per_object_overhead = 4 # extra overhead bytes per object in the request + + def request_maker(): + return batch_pb2.BatchStreamRequest() + + request = request_maker() + total_size = request.ByteSize() + + inflight_objs = set() + inflight_refs = set() + for object_ in objects: + obj = self.__batch_grpc.grpc_object(object_._to_internal()) + obj_size = obj.ByteSize() + per_object_overhead + + if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size: + 
self.__inflight_objs.update(inflight_objs) + self.__inflight_refs.update(inflight_refs) + yield request + request = request_maker() + total_size = request.ByteSize() + + request.data.objects.values.append(obj) + total_size += obj_size + inflight_objs.add(obj.uuid) + + for reference in references: + ref = self.__batch_grpc.grpc_reference(reference._to_internal()) + ref_size = ref.ByteSize() + per_object_overhead + + if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size: + self.__inflight_objs.update(inflight_objs) + self.__inflight_refs.update(inflight_refs) + yield request + request = request_maker() + total_size = request.ByteSize() + + request.data.references.values.append(ref) + total_size += ref_size + inflight_refs.add(reference._to_beacon()) + + if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0: + self.__inflight_objs.update(inflight_objs) + self.__inflight_refs.update(inflight_refs) + yield request + + async def __end_stream(self): + await self.__connection.grpc_batch_stream_write( + self.__stream, + batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()), + ) + await self.__stream.done_writing() + + async def __send(self): + await self.__connection.grpc_batch_stream_write( + stream=self.__stream, + request=batch_pb2.BatchStreamRequest( + start=batch_pb2.BatchStreamRequest.Start( + consistency_level=self.__batch_grpc._consistency_level, + ), + ), + ) + while self.__bg_exception is None: + if self.__is_gcp_on_wcd: + assert self.__stream_start is not None + if time.time() - self.__stream_start > GCP_STREAM_TIMEOUT: + logger.warning( + "GCP connections have a maximum lifetime. Re-establishing the batch stream to avoid timeout errors." + ) + self.__is_renewing_stream.set() + await self.__end_stream() + return + try: + req = await asyncio.wait_for(self.__reqs.get(), timeout=1) + except asyncio.TimeoutError: + continue + if req is not None: + await self.__connection.grpc_batch_stream_write(self.__stream, req) + continue + if self.__stop and not ( + self.__is_shutting_down.is_set() or self.__is_shutdown.is_set() + ): + logger.warning("Batching finished, closing the client-side of the stream") + await self.__end_stream() + return + if self.__is_shutting_down.is_set(): + logger.warning("Server shutting down, closing the client-side of the stream") + await self.__stream.done_writing() + return + if self.__is_oom.is_set(): + logger.warning("Server out-of-memory, closing the client-side of the stream") + await self.__stream.done_writing() + return + logger.warning("Received sentinel, but not stopping, continuing...") + + async def __recv(self) -> None: + while self.__bg_exception is None: + message = await self.__connection.grpc_batch_stream_read(self.__stream) + if not isinstance(message, batch_pb2.BatchStreamReply): + logger.warning("Server closed the stream from its side, shutting down batch") + return + if message.HasField("started"): + logger.warning("Batch stream started successfully") + if message.HasField("backoff"): + if ( + message.backoff.batch_size != self.__batch_size + and not self.__is_shutting_down.is_set() + and not self.__is_shutdown.is_set() + and not self.__stop + ): + self.__batch_size = message.backoff.batch_size + logger.warning( + f"Updated batch size to {self.__batch_size} as per server request" + ) + if message.HasField("acks"): + self.__inflight_objs.difference_update(message.acks.uuids) + self.__uuid_lookup.difference_update(message.acks.uuids) + self.__inflight_refs.difference_update(message.acks.beacons) + 
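
The acks branch above is one half of the client-side flow control: __recv shrinks the inflight sets as the server acknowledges requests, while _add_object (later in this file) blocks whenever __inflight_objs reaches the server-advertised batch size. A stripped-down sketch of that pattern with hypothetical names, not the actual classes:

    import asyncio
    from typing import List, Set

    class TinyFlowControl:
        """Toy model of the inflight/ack window used by _BatchBaseAsync."""

        def __init__(self, window: int) -> None:
            self.window = window             # analogous to __batch_size
            self.inflight: Set[str] = set()  # analogous to __inflight_objs

        async def add(self, uuid: str) -> None:
            # the producer blocks until the server has acked enough requests
            while len(self.inflight) >= self.window:
                await asyncio.sleep(0.01)
            self.inflight.add(uuid)

        def on_acks(self, uuids: List[str]) -> None:
            # mirrors the BatchStreamReply.acks handling above
            self.inflight.difference_update(uuids)
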
if message.HasField("results"): + result_objs = BatchObjectReturn() + result_refs = BatchReferenceReturn() + failed_objs: List[ErrorObject] = [] + failed_refs: List[ErrorReference] = [] + for error in message.results.errors: + if error.HasField("uuid"): + try: + cached = self.__objs_cache.pop(error.uuid) + except KeyError: + continue + err = ErrorObject( + message=error.error, + object_=cached, + ) + result_objs += BatchObjectReturn( + _all_responses=[err], + errors={cached.index: err}, + ) + failed_objs.append(err) + logger.warning( + { + "error": error.error, + "object": error.uuid, + "action": "use {client,collection}.batch.failed_objects to access this error", + } + ) + if error.HasField("beacon"): + try: + cached = self.__refs_cache.pop(error.beacon) + except KeyError: + continue + err = ErrorReference( + message=error.error, + reference=cached, + ) + result_refs += BatchReferenceReturn( + errors={cached.index: err}, + ) + failed_refs.append(err) + logger.warning( + { + "error": error.error, + "reference": error.beacon, + "action": "use {client,collection}.batch.failed_references to access this error", + } + ) + for success in message.results.successes: + if success.HasField("uuid"): + try: + cached = self.__objs_cache.pop(success.uuid) + except KeyError: + continue + uuid = uuid_package.UUID(success.uuid) + result_objs += BatchObjectReturn( + _all_responses=[uuid], + uuids={cached.index: uuid}, + ) + if success.HasField("beacon"): + try: + self.__refs_cache.pop(success.beacon, None) + except KeyError: + continue + self.__results_for_wrapper.results.objs += result_objs + self.__results_for_wrapper.results.refs += result_refs + self.__results_for_wrapper.failed_objects.extend(failed_objs) + self.__results_for_wrapper.failed_references.extend(failed_refs) + if message.HasField("out_of_memory"): + logger.warning( + "Server reported out-of-memory error. Batching will wait at most 10 minutes for the server to scale-up. If the server does not recover within this time, the batch will terminate with an error." + ) + self.__is_oom.set() + await self.__batch_objects.aprepend( + [self.__objs_cache[uuid] for uuid in message.out_of_memory.uuids] + ) + await self.__batch_references.aprepend( + [self.__refs_cache[beacon] for beacon in message.out_of_memory.beacons] + ) + if message.HasField("shutting_down"): + logger.warning( + "Received shutting down message from server, pausing sending until stream is re-established" + ) + self.__is_shutting_down.set() + if message.HasField("shutdown"): + logger.warning("Received shutdown finished message from server") + self.__is_shutdown.set() + self.__is_shutting_down.clear() + await self.__reconnect() + + # restart the stream if we were shutdown by the node we were connected to + if self.__is_shutdown.is_set(): + logger.warning("Restarting batch recv after shutdown...") + self.__is_shutdown.clear() + return await self.__recv() + + async def __reconnect(self, retry: int = 0) -> None: + try: + logger.warning(f"Trying to reconnect after shutdown... 
{retry + 1}/{5}") + self.__connection.close("sync") + await self.__connection.connect(force=True) + logger.warning("Reconnected successfully") + self.__stream = self.__connection.grpc_batch_stream() + except (WeaviateStartUpError, WeaviateGRPCUnavailableError) as e: + if retry < 5: + await asyncio.sleep(2**retry) + await self.__reconnect(retry + 1) + else: + logger.error("Failed to reconnect after 5 attempts") + self.__bg_thread_exception = e + + async def flush(self) -> None: + """Flush the batch queue and wait for all requests to be finished.""" + # bg thread is sending objs+refs automatically, so simply wait for everything to be done + while len(self.__batch_objects) > 0 or len(self.__batch_references) > 0: + await asyncio.sleep(0.01) + + async def _add_object( + self, + collection: str, + properties: Optional[WeaviateProperties] = None, + references: Optional[ReferenceInputs] = None, + uuid: Optional[UUID] = None, + vector: Optional[VECTORS] = None, + tenant: Optional[str] = None, + ) -> UUID: + self.__check_bg_tasks_alive() + try: + batch_object = BatchObject( + collection=collection, + properties=properties, + references=references, + uuid=uuid, + vector=vector, + tenant=tenant, + index=self.__objs_count, + ) + self.__results_for_wrapper.imported_shards.add( + Shard(collection=collection, tenant=tenant) + ) + except ValidationError as e: + raise WeaviateBatchValidationError(repr(e)) + uuid = str(batch_object.uuid) + self.__uuid_lookup.add(uuid) + await self.__batch_objects.aadd(batch_object) + self.__objs_cache[uuid] = batch_object + self.__objs_count += 1 + + while len(self.__inflight_objs) >= self.__batch_size: + self.__check_bg_tasks_alive() + await asyncio.sleep(0.01) + + assert batch_object.uuid is not None + return batch_object.uuid + + async def _add_reference( + self, + from_object_uuid: UUID, + from_object_collection: str, + from_property_name: str, + to: ReferenceInput, + tenant: Optional[str] = None, + ) -> None: + self.__check_bg_tasks_alive() + if isinstance(to, ReferenceToMulti): + to_strs: Union[List[str], List[UUID]] = to.uuids_str + elif isinstance(to, str) or isinstance(to, uuid_package.UUID): + to_strs = [to] + else: + to_strs = list(to) + + for uid in to_strs: + try: + batch_reference = BatchReference( + from_object_collection=from_object_collection, + from_object_uuid=from_object_uuid, + from_property_name=from_property_name, + to_object_collection=( + to.target_collection if isinstance(to, ReferenceToMulti) else None + ), + to_object_uuid=uid, + tenant=tenant, + index=self.__refs_count, + ) + except ValidationError as e: + raise WeaviateBatchValidationError(repr(e)) + await self.__batch_references.aadd(batch_reference) + self.__refs_cache[batch_reference._to_beacon()] = batch_reference + self.__refs_count += 1 + while len(self.__inflight_refs) >= self.__batch_size * 2: + self.__check_bg_tasks_alive() + await asyncio.sleep(0.01) + + def __check_bg_tasks_alive(self) -> None: + if self.__all_tasks_alive(): + return + + raise self.__bg_exception or Exception("Batch tasks died unexpectedly") diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index 47bb5e1da..d56c2879d 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -1,3 +1,4 @@ +import asyncio import contextvars import functools import math @@ -10,8 +11,7 @@ from concurrent.futures import ThreadPoolExecutor from copy import copy from dataclasses import dataclass, field -from queue import Queue -from typing import Any, Dict, Generator, 
Generic, List, Optional, Set, TypeVar, Union, cast +from typing import Any, Dict, Generic, List, Optional, Set, TypeVar, Union, cast from pydantic import ValidationError from typing_extensions import TypeAlias @@ -37,16 +37,12 @@ ) from weaviate.collections.classes.types import WeaviateProperties from weaviate.connect import executor -from weaviate.connect.v4 import ConnectionSync +from weaviate.connect.v4 import ConnectionAsync, ConnectionSync from weaviate.exceptions import ( - WeaviateBatchFailedToReestablishStreamError, - WeaviateBatchStreamError, + EmptyResponseException, WeaviateBatchValidationError, - WeaviateGRPCUnavailableError, - WeaviateStartUpError, ) from weaviate.logger import logger -from weaviate.proto.v1 import batch_pb2 from weaviate.types import UUID, VECTORS from weaviate.util import _decode_json_response_dict from weaviate.warnings import _Warnings @@ -75,24 +71,42 @@ class BatchRequest(ABC, Generic[TBatchInput, TBatchReturn]): def __init__(self) -> None: self._items: List[TBatchInput] = [] self._lock = threading.Lock() + self._alock = asyncio.Lock() def __len__(self) -> int: - return len(self._items) + with self._lock: + return len(self._items) + + async def alen(self) -> int: + """Asynchronously get the length of the BatchRequest.""" + async with self._alock: + return len(self._items) def add(self, item: TBatchInput) -> None: """Add an item to the BatchRequest.""" - self._lock.acquire() - self._items.append(item) - self._lock.release() + with self._lock: + self._items.append(item) + + async def aadd(self, item: TBatchInput) -> None: + """Asynchronously add an item to the BatchRequest.""" + async with self._alock: + self._items.append(item) def prepend(self, item: List[TBatchInput]) -> None: """Add items to the front of the BatchRequest. This is intended to be used when objects should be retries, eg. after a temporary error. """ - self._lock.acquire() - self._items = item + self._items - self._lock.release() + with self._lock: + self._items = item + self._items + + async def aprepend(self, item: List[TBatchInput]) -> None: + """Asynchronously add items to the front of the BatchRequest. + + This is intended to be used when objects should be retries, eg. after a temporary error. + """ + async with self._alock: + self._items = item + self._items Ref = TypeVar("Ref", bound=BatchReference) @@ -101,15 +115,9 @@ def prepend(self, item: List[TBatchInput]) -> None: class ReferencesBatchRequest(BatchRequest[Ref, BatchReferenceReturn]): """Collect Weaviate-object references to add them in one request to Weaviate.""" - def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]: - """Pop the given number of items from the BatchRequest queue. - - Returns: - A list of items from the BatchRequest. - """ + def __pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]: ret: List[Ref] = [] i = 0 - self._lock.acquire() while len(ret) < pop_amount and len(self._items) > 0 and i < len(self._items): if self._items[i].from_object_uuid not in uuid_lookup and ( self._items[i].to_object_uuid is None @@ -118,19 +126,48 @@ def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]: ret.append(self._items.pop(i)) else: i += 1 - self._lock.release() return ret + def pop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]: + """Pop the given number of items from the BatchRequest queue. + + Returns: + A list of items from the BatchRequest. 
+ """ + with self._lock: + return self.__pop_items(pop_amount, uuid_lookup) + + async def apop_items(self, pop_amount: int, uuid_lookup: Set[str]) -> List[Ref]: + """Asynchronously pop the given number of items from the BatchRequest queue. + + Returns: + A list of items from the BatchRequest. + """ + async with self._alock: + return self.__pop_items(pop_amount, uuid_lookup) + + def __head(self) -> Optional[Ref]: + if len(self._items) > 0: + return self._items[0] + return None + def head(self) -> Optional[Ref]: """Get the first item from the BatchRequest queue without removing it. Returns: The first item from the BatchRequest or None if the queue is empty. """ - self._lock.acquire() - item = self._items[0] if len(self._items) > 0 else None - self._lock.release() - return item + with self._lock: + return self.__head() + + async def ahead(self) -> Optional[Ref]: + """Asynchronously get the first item from the BatchRequest queue without removing it. + + Returns: + The first item from the BatchRequest or None if the queue is empty. + """ + async with self._alock: + return self.__head() Obj = TypeVar("Obj", bound=BatchObject) @@ -139,33 +176,55 @@ def head(self) -> Optional[Ref]: class ObjectsBatchRequest(Generic[Obj], BatchRequest[Obj, BatchObjectReturn]): """Collect objects for one batch request to weaviate.""" - def pop_items(self, pop_amount: int) -> List[Obj]: - """Pop the given number of items from the BatchRequest queue. - - Returns: - A list of items from the BatchRequest. - """ - self._lock.acquire() + def __pop_items(self, pop_amount: int) -> List[Obj]: if pop_amount >= len(self._items): ret = copy(self._items) self._items.clear() else: ret = copy(self._items[:pop_amount]) self._items = self._items[pop_amount:] - - self._lock.release() return ret + def pop_items(self, pop_amount: int) -> List[Obj]: + """Pop the given number of items from the BatchRequest queue. + + Returns: + A list of items from the BatchRequest. + """ + with self._lock: + return self.__pop_items(pop_amount) + + async def apop_items(self, pop_amount: int) -> List[Obj]: + """Asynchronously pop the given number of items from the BatchRequest queue. + + Returns: + A list of items from the BatchRequest. + """ + async with self._alock: + return self.__pop_items(pop_amount) + + def __head(self) -> Optional[Obj]: + if len(self._items) > 0: + return self._items[0] + return None + def head(self) -> Optional[Obj]: """Get the first item from the BatchRequest queue without removing it. Returns: The first item from the BatchRequest or None if the queue is empty. """ - self._lock.acquire() - item = self._items[0] if len(self._items) > 0 else None - self._lock.release() - return item + with self._lock: + return self.__head() + + async def ahead(self) -> Optional[Obj]: + """Asynchronously get the first item from the BatchRequest queue without removing it. + + Returns: + The first item from the BatchRequest or None if the queue is empty. 
+ """ + async with self._alock: + return self.__head() @dataclass @@ -287,7 +346,7 @@ def __init__( self.__uuid_lookup_lock = threading.Lock() self.__results_lock = threading.Lock() - self.__bg_thread = self.__start_bg_threads() + self.__bg_threads = self.__start_bg_threads() self.__bg_thread_exception: Optional[Exception] = None @property @@ -306,7 +365,7 @@ def _shutdown(self) -> None: # we are done, shut bg threads down and end the event loop self.__shut_background_thread_down.set() - while self.__bg_thread.is_alive(): + while self.__bg_threads.is_alive(): time.sleep(0.01) # copy the results to the public results @@ -714,7 +773,7 @@ def flush(self) -> None: or len(self.__batch_references) > 0 ): time.sleep(0.01) - self.__check_bg_thread_alive() + self.__check_bg_threads_alive() def _add_object( self, @@ -725,7 +784,7 @@ def _add_object( vector: Optional[VECTORS] = None, tenant: Optional[str] = None, ) -> UUID: - self.__check_bg_thread_alive() + self.__check_bg_threads_alive() try: batch_object = BatchObject( collection=collection, @@ -751,7 +810,7 @@ def _add_object( self.__recommended_num_objects == 0 or len(self.__batch_objects) >= self.__recommended_num_objects * 2 ): - self.__check_bg_thread_alive() + self.__check_bg_threads_alive() time.sleep(0.01) assert batch_object.uuid is not None @@ -765,7 +824,7 @@ def _add_reference( to: ReferenceInput, tenant: Optional[str] = None, ) -> None: - self.__check_bg_thread_alive() + self.__check_bg_threads_alive() if isinstance(to, ReferenceToMulti): to_strs: Union[List[str], List[UUID]] = to.uuids_str elif isinstance(to, str) or isinstance(to, uuid_package.UUID): @@ -794,640 +853,48 @@ def _add_reference( # block if queue gets too long or weaviate is overloaded while self.__recommended_num_objects == 0: time.sleep(0.01) # block if weaviate is overloaded, also do not send any refs - self.__check_bg_thread_alive() + self.__check_bg_threads_alive() - def __check_bg_thread_alive(self) -> None: - if self.__bg_thread.is_alive(): + def __check_bg_threads_alive(self) -> None: + if self.__bg_threads.is_alive(): return raise self.__bg_thread_exception or Exception("Batch thread died unexpectedly") class _BgThreads: - def __init__(self, send: threading.Thread, recv: threading.Thread): - self.send = send + def __init__(self, loop: threading.Thread, recv: threading.Thread): + self.loop = loop self.recv = recv self.__started_recv = False - self.__started_send = False + self.__started_loop = False def start_recv(self) -> None: if not self.__started_recv: self.recv.start() self.__started_recv = True - def start_send(self) -> None: - if not self.__started_send: - self.send.start() - self.__started_send = True + def start_loop(self) -> None: + if not self.__started_loop: + self.loop.start() + self.__started_loop = True def is_alive(self) -> bool: """Check if the background threads are still alive.""" - return self.send_alive() or self.recv_alive() + return self.loop_alive() or self.recv_alive() - def send_alive(self) -> bool: - """Check if the send background thread is still alive.""" - return self.send.is_alive() + def loop_alive(self) -> bool: + """Check if the loop background thread is still alive.""" + return self.loop.is_alive() def recv_alive(self) -> bool: """Check if the recv background thread is still alive.""" return self.recv.is_alive() - -class _BatchBaseNew: - def __init__( - self, - connection: ConnectionSync, - consistency_level: Optional[ConsistencyLevel], - results: _BatchDataWrapper, - batch_mode: _BatchMode, - executor: ThreadPoolExecutor, - 
vectorizer_batching: bool, - objects: Optional[ObjectsBatchRequest[BatchObject]] = None, - references: Optional[ReferencesBatchRequest[BatchReference]] = None, - ) -> None: - self.__batch_objects = objects or ObjectsBatchRequest[BatchObject]() - self.__batch_references = references or ReferencesBatchRequest[BatchReference]() - - self.__connection = connection - self.__is_gcp_on_wcd = connection._connection_params.is_gcp_on_wcd() - self.__stream_start: Optional[float] = None - self.__is_renewing_stream = threading.Event() - - self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM - self.__batch_size = 100 - - self.__batch_grpc = _BatchGRPC( - connection._weaviate_version, self.__consistency_level, connection._grpc_max_msg_size - ) - self.__cluster = _ClusterBatch(self.__connection) - self.__number_of_nodes = self.__cluster.get_number_of_nodes() - - # lookup table for objects that are currently being processed - is used to not send references from objects that have not been added yet - self.__uuid_lookup: Set[str] = set() - - # we do not want that users can access the results directly as they are not thread-safe - self.__results_for_wrapper_backup = results - self.__results_for_wrapper = _BatchDataWrapper() - - self.__objs_count = 0 - self.__refs_count = 0 - - self.__uuid_lookup_lock = threading.Lock() - self.__results_lock = threading.Lock() - - self.__bg_thread_exception: Optional[Exception] = None - self.__is_oom = threading.Event() - self.__is_shutting_down = threading.Event() - self.__is_shutdown = threading.Event() - - self.__objs_cache_lock = threading.Lock() - self.__refs_cache_lock = threading.Lock() - self.__objs_cache: dict[str, BatchObject] = {} - self.__refs_cache: dict[str, BatchReference] = {} - - self.__acks_lock = threading.Lock() - self.__inflight_objs: set[str] = set() - self.__inflight_refs: set[str] = set() - - # maxsize=1 so that __batch_send does not run faster than generator for __batch_recv - # thereby using too much buffer in case of server-side shutdown - self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1) - - self.__stop = False - - self.__batch_mode = batch_mode - - self.__total = 0 - - @property - def number_errors(self) -> int: - """Return the number of errors in the batch.""" - return len(self.__results_for_wrapper.failed_objects) + len( - self.__results_for_wrapper.failed_references - ) - - def __all_threads_alive(self) -> bool: - return self.__bg_threads is not None and all( - thread.is_alive() for thread in self.__bg_threads - ) - - def __any_threads_alive(self) -> bool: - return self.__bg_threads is not None and any( - thread.is_alive() for thread in self.__bg_threads - ) - - def _start(self) -> None: - assert isinstance(self.__batch_mode, _ServerSideBatching), ( - "Only server-side batching is supported in this mode" - ) - self.__bg_threads = [ - self.__start_bg_threads() for _ in range(self.__batch_mode.concurrency) - ] - logger.warning( - f"Provisioned {len(self.__bg_threads)} stream(s) to the server for batch processing" - ) - now = time.time() - while not self.__all_threads_alive(): - # wait for the stream to be started by __batch_stream - time.sleep(0.01) - if time.time() - now > 60: - raise WeaviateBatchStreamError( - "Batch stream was not started within 60 seconds. Please check your connection." 
- ) - - def _shutdown(self) -> None: - # Shutdown the current batch and wait for all requests to be finished - self.flush() - self.__stop = True - - # we are done, wait for bg threads to finish - # self.__batch_stream will set the shutdown event when it receives - # the stop message from the server - while self.__any_threads_alive(): - time.sleep(0.05) - logger.warning("Send & receive threads finished.") - - # copy the results to the public results - self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results - self.__results_for_wrapper_backup.failed_objects = self.__results_for_wrapper.failed_objects - self.__results_for_wrapper_backup.failed_references = ( - self.__results_for_wrapper.failed_references - ) - self.__results_for_wrapper_backup.imported_shards = ( - self.__results_for_wrapper.imported_shards - ) - - def __batch_send(self) -> None: - refresh_time: float = 0.01 - while ( - self.__shut_background_thread_down is not None - and not self.__shut_background_thread_down.is_set() - ): - if len(self.__batch_objects) + len(self.__batch_references) > 0: - self._batch_send = True - start = time.time() - while (len_o := len(self.__batch_objects)) + ( - len_r := len(self.__batch_references) - ) < self.__batch_size: - # wait for more objects to be added up to the batch size - time.sleep(0.01) - if ( - self.__shut_background_thread_down is not None - and self.__shut_background_thread_down.is_set() - ): - logger.warning("Threads were shutdown, exiting batch send loop") - # shutdown was requested, exit early - self.__reqs.put(None) - return - if time.time() - start >= 1 and ( - len_o == len(self.__batch_objects) or len_r == len(self.__batch_references) - ): - # no new objects were added in the last second, exit the loop - break - - objs = self.__batch_objects.pop_items(self.__batch_size) - refs = self.__batch_references.pop_items( - self.__batch_size - len(objs), - uuid_lookup=self.__uuid_lookup, - ) - with self.__uuid_lookup_lock: - self.__uuid_lookup.difference_update(obj.uuid for obj in objs) - - for req in self.__generate_stream_requests(objs, refs): - logged = False - start = time.time() - while ( - self.__is_oom.is_set() - or self.__is_shutting_down.is_set() - or self.__is_shutdown.is_set() - ): - # if we were shutdown by the node we were connected to, we need to wait for the stream to be restarted - # so that the connection is refreshed to a new node where the objects can be accepted - # otherwise, we wait until the stream has been started by __batch_stream to send the first batch - if not logged: - logger.warning("Waiting for stream to be re-established...") - logged = True - # put sentinel into our queue to signal the end of the current stream - self.__reqs.put(None) - time.sleep(1) - if time.time() - start > 300: - raise WeaviateBatchFailedToReestablishStreamError( - "Batch stream was not re-established within 5 minutes. Terminating batch." 
- ) - if logged: - logger.warning("Stream re-established, resuming sending batches") - self.__reqs.put(req) - elif self.__stop: - # we are done, send the sentinel into our queue to be consumed by the batch sender - self.__reqs.put(None) # signal the end of the stream - logger.warning("Batching finished, sent stop signal to batch stream") - return - time.sleep(refresh_time) - - def __beacon(self, ref: batch_pb2.BatchReference) -> str: - return f"weaviate://localhost/{ref.from_collection}{f'#{ref.tenant}' if ref.tenant != '' else ''}/{ref.from_uuid}#{ref.name}->/{ref.to_collection}/{ref.to_uuid}" - - def __generate_stream_requests( - self, - objects: List[BatchObject], - references: List[BatchReference], - ) -> Generator[batch_pb2.BatchStreamRequest, None, None]: - per_object_overhead = 4 # extra overhead bytes per object in the request - - def request_maker(): - return batch_pb2.BatchStreamRequest() - - request = request_maker() - total_size = request.ByteSize() - - inflight_objs = set() - inflight_refs = set() - for object_ in objects: - obj = self.__batch_grpc.grpc_object(object_._to_internal()) - obj_size = obj.ByteSize() + per_object_overhead - - if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size: - yield request - request = request_maker() - total_size = request.ByteSize() - - request.data.objects.values.append(obj) - total_size += obj_size - if self.__connection._weaviate_version.is_at_least(1, 35, 0): - inflight_objs.add(obj.uuid) - - for reference in references: - ref = self.__batch_grpc.grpc_reference(reference._to_internal()) - ref_size = ref.ByteSize() + per_object_overhead - - if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size: - yield request - request = request_maker() - total_size = request.ByteSize() - - request.data.references.values.append(ref) - total_size += ref_size - if self.__connection._weaviate_version.is_at_least(1, 35, 0): - inflight_refs.add(reference._to_beacon()) - - with self.__acks_lock: - self.__inflight_objs.update(inflight_objs) - self.__inflight_refs.update(inflight_refs) - - if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0: - yield request - - def __generate_stream_requests_for_grpc( - self, - ) -> Generator[batch_pb2.BatchStreamRequest, None, None]: - yield batch_pb2.BatchStreamRequest( - start=batch_pb2.BatchStreamRequest.Start( - consistency_level=self.__batch_grpc._consistency_level, - ), - ) - while ( - self.__shut_background_thread_down is not None - and not self.__shut_background_thread_down.is_set() - ): - if self.__is_gcp_on_wcd: - assert self.__stream_start is not None - if time.time() - self.__stream_start > GCP_STREAM_TIMEOUT: - logger.warning( - "GCP connections have a maximum lifetime. Re-establishing the batch stream to avoid timeout errors." 
- ) - yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) - self.__is_renewing_stream.set() - return - req = self.__reqs.get() - if req is not None: - self.__total += len(req.data.objects.values) + len(req.data.references.values) - yield req - continue - if self.__stop and not ( - self.__is_shutting_down.is_set() or self.__is_shutdown.is_set() - ): - logger.warning("Batching finished, closing the client-side of the stream") - yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) - return - if self.__is_shutting_down.is_set(): - logger.warning("Server shutting down, closing the client-side of the stream") - return - if self.__is_oom.is_set(): - logger.warning("Server out-of-memory, closing the client-side of the stream") - return - logger.warning("Received sentinel, but not stopping, continuing...") - - def __batch_recv(self) -> None: - self.__stream_start = time.time() - for message in self.__batch_grpc.stream( - connection=self.__connection, - requests=self.__generate_stream_requests_for_grpc(), - ): - if message.HasField("started"): - logger.warning("Batch stream started successfully") - for threads in self.__bg_threads: - threads.start_send() - if message.HasField("backoff"): - if ( - message.backoff.batch_size != self.__batch_size - and not self.__is_shutting_down.is_set() - and not self.__is_shutdown.is_set() - and not self.__stop - ): - self.__batch_size = message.backoff.batch_size - logger.warning( - f"Updated batch size to {self.__batch_size} as per server request" - ) - if message.HasField("acks"): - with self.__acks_lock: - self.__inflight_objs.difference_update(message.acks.uuids) - self.__uuid_lookup.difference_update(message.acks.uuids) - self.__inflight_refs.difference_update(message.acks.beacons) - if message.HasField("results"): - result_objs = BatchObjectReturn() - result_refs = BatchReferenceReturn() - failed_objs: List[ErrorObject] = [] - failed_refs: List[ErrorReference] = [] - for error in message.results.errors: - if error.HasField("uuid"): - try: - cached = self.__objs_cache.pop(error.uuid) - except KeyError: - continue - err = ErrorObject( - message=error.error, - object_=cached, - ) - result_objs += BatchObjectReturn( - _all_responses=[err], - errors={cached.index: err}, - ) - failed_objs.append(err) - logger.warning( - { - "error": error.error, - "object": error.uuid, - "action": "use {client,collection}.batch.failed_objects to access this error", - } - ) - if error.HasField("beacon"): - try: - cached = self.__refs_cache.pop(error.beacon) - except KeyError: - continue - err = ErrorReference( - message=error.error, - reference=error.beacon, # pyright: ignore - ) - failed_refs.append(err) - result_refs += BatchReferenceReturn( - errors={cached.index: err}, - ) - logger.warning( - { - "error": error.error, - "reference": error.beacon, - "action": "use {client,collection}.batch.failed_references to access this error", - } - ) - for success in message.results.successes: - if success.HasField("uuid"): - try: - cached = self.__objs_cache.pop(success.uuid) - except KeyError: - continue - uuid = uuid_package.UUID(success.uuid) - result_objs += BatchObjectReturn( - _all_responses=[uuid], - uuids={cached.index: uuid}, - ) - if success.HasField("beacon"): - try: - self.__refs_cache.pop(success.beacon, None) - except KeyError: - continue - with self.__results_lock: - self.__results_for_wrapper.results.objs += result_objs - self.__results_for_wrapper.results.refs += result_refs - 
self.__results_for_wrapper.failed_objects.extend(failed_objs) - self.__results_for_wrapper.failed_references.extend(failed_refs) - if message.HasField("out_of_memory"): - logger.warning( - "Server reported out-of-memory error. Batching will wait at most 10 minutes for the server to scale-up. If the server does not recover within this time, the batch will terminate with an error." - ) - self.__is_oom.set() - self.__batch_objects.prepend( - [self.__objs_cache[uuid] for uuid in message.out_of_memory.uuids] - ) - self.__batch_references.prepend( - [self.__refs_cache[beacon] for beacon in message.out_of_memory.beacons] - ) - if message.HasField("shutting_down"): - logger.warning( - "Received shutting down message from server, pausing sending until stream is re-established" - ) - self.__is_shutting_down.set() - self.__is_oom.clear() - if message.HasField("shutdown"): - logger.warning("Received shutdown finished message from server") - self.__is_shutdown.set() - self.__is_shutting_down.clear() - self.__reconnect() - - if self.__is_shutdown.is_set(): - # restart the stream if we were shutdown by the node we were connected to - logger.warning("Restarting batch recv after shutdown...") - self.__is_shutdown.clear() - return self.__batch_recv() - elif self.__is_renewing_stream.is_set(): - # restart the stream if we are renewing it (GCP connections have a max lifetime) - logger.warning("Restarting batch recv after renewing stream...") - self.__is_renewing_stream.clear() - return self.__batch_recv() - else: - logger.warning("Server closed the stream from its side, shutting down batch") - return - - def __reconnect(self, retry: int = 0) -> None: - if self.__consistency_level == ConsistencyLevel.ALL or self.__number_of_nodes == 1: - # check that all nodes are available before reconnecting - up_nodes = self.__cluster.get_nodes_status() - while len(up_nodes) != self.__number_of_nodes or any( - node["status"] != "HEALTHY" for node in up_nodes - ): - logger.warning( - "Waiting for all nodes to be HEALTHY before reconnecting to batch stream..." - ) - time.sleep(5) - up_nodes = self.__cluster.get_nodes_status() - try: - logger.warning(f"Trying to reconnect after shutdown... 
{retry + 1}/{5}") - self.__connection.close("sync") - self.__connection.connect(force=True) - logger.warning("Reconnected successfully") - except (WeaviateStartUpError, WeaviateGRPCUnavailableError) as e: - if retry < 5: - time.sleep(2**retry) - self.__reconnect(retry + 1) - else: - logger.error("Failed to reconnect after 5 attempts") - self.__bg_thread_exception = e - - def __start_bg_threads(self) -> _BgThreads: - """Create a background thread that periodically checks how congested the batch queue is.""" - self.__shut_background_thread_down = threading.Event() - - def batch_send_wrapper() -> None: - try: - self.__batch_send() - logger.warning("exited batch send thread") - except Exception as e: - logger.error(e) - self.__bg_thread_exception = e - - def batch_recv_wrapper() -> None: - socket_hung_up = False - try: - self.__batch_recv() - logger.warning("exited batch receive thread") - except Exception as e: - if isinstance(e, WeaviateBatchStreamError) and ( - "Socket closed" in e.message - or "context canceled" in e.message - or "Connection reset" in e.message - or "Received RST_STREAM with error code 2" in e.message - ): - logger.error(f"Socket hung up detected in batch receive thread: {e.message}") - socket_hung_up = True - else: - logger.error(e) - logger.error(type(e)) - self.__bg_thread_exception = e - if socket_hung_up: - # this happens during ungraceful shutdown of the coordinator - # lets restart the stream and add the cached objects again - logger.warning("Stream closed unexpectedly, restarting...") - self.__reconnect() - # server sets this whenever it restarts, gracefully or unexpectedly, so need to clear it now - self.__is_shutting_down.clear() - with self.__objs_cache_lock: - logger.warning( - f"Re-adding {len(self.__objs_cache)} cached objects to the batch" - ) - self.__batch_objects.prepend(list(self.__objs_cache.values())) - with self.__refs_cache_lock: - self.__batch_references.prepend(list(self.__refs_cache.values())) - # start a new stream with a newly reconnected channel - return batch_recv_wrapper() - - threads = _BgThreads( - send=threading.Thread( - target=batch_send_wrapper, - daemon=True, - name="BgBatchSend", - ), - recv=threading.Thread( - target=batch_recv_wrapper, - daemon=True, - name="BgBatchRecv", - ), - ) - threads.start_recv() - return threads - - def flush(self) -> None: - """Flush the batch queue and wait for all requests to be finished.""" - # bg thread is sending objs+refs automatically, so simply wait for everything to be done - while len(self.__batch_objects) > 0 or len(self.__batch_references) > 0: - time.sleep(0.01) - self.__check_bg_threads_alive() - - def _add_object( - self, - collection: str, - properties: Optional[WeaviateProperties] = None, - references: Optional[ReferenceInputs] = None, - uuid: Optional[UUID] = None, - vector: Optional[VECTORS] = None, - tenant: Optional[str] = None, - ) -> UUID: - self.__check_bg_threads_alive() - try: - batch_object = BatchObject( - collection=collection, - properties=properties, - references=references, - uuid=uuid, - vector=vector, - tenant=tenant, - index=self.__objs_count, - ) - self.__results_for_wrapper.imported_shards.add( - Shard(collection=collection, tenant=tenant) - ) - except ValidationError as e: - raise WeaviateBatchValidationError(repr(e)) - uuid = str(batch_object.uuid) - with self.__uuid_lookup_lock: - self.__uuid_lookup.add(uuid) - self.__batch_objects.add(batch_object) - with self.__objs_cache_lock: - self.__objs_cache[uuid] = batch_object - self.__objs_count += 1 - - # block if queue 
gets too long or weaviate is overloaded - reading files is faster them sending them so we do - # not need a long queue - while len(self.__inflight_objs) >= self.__batch_size: - self.__check_bg_threads_alive() - time.sleep(0.01) - - assert batch_object.uuid is not None - return batch_object.uuid - - def _add_reference( - self, - from_object_uuid: UUID, - from_object_collection: str, - from_property_name: str, - to: ReferenceInput, - tenant: Optional[str] = None, - ) -> None: - self.__check_bg_threads_alive() - if isinstance(to, ReferenceToMulti): - to_strs: Union[List[str], List[UUID]] = to.uuids_str - elif isinstance(to, str) or isinstance(to, uuid_package.UUID): - to_strs = [to] - else: - to_strs = list(to) - - for uid in to_strs: - try: - batch_reference = BatchReference( - from_object_collection=from_object_collection, - from_object_uuid=from_object_uuid, - from_property_name=from_property_name, - to_object_collection=( - to.target_collection if isinstance(to, ReferenceToMulti) else None - ), - to_object_uuid=uid, - tenant=tenant, - index=self.__refs_count, - ) - except ValidationError as e: - raise WeaviateBatchValidationError(repr(e)) - self.__batch_references.add(batch_reference) - with self.__refs_cache_lock: - self.__refs_cache[batch_reference._to_beacon()] = batch_reference - self.__refs_count += 1 - while len(self.__inflight_refs) >= self.__batch_size * 2: - self.__check_bg_threads_alive() - time.sleep(0.01) - - def __check_bg_threads_alive(self) -> None: - if self.__any_threads_alive(): - return - - raise self.__bg_thread_exception or Exception("Batch thread died unexpectedly") + def join(self) -> None: + """Join the background threads.""" + self.loop.join() + self.recv.join() class _ClusterBatch: @@ -1451,3 +918,26 @@ def get_nodes_status( def get_number_of_nodes(self) -> int: return len(self.get_nodes_status()) + + +class _ClusterBatchAsync: + def __init__(self, connection: ConnectionAsync): + self._connection = connection + + async def get_nodes_status( + self, + ) -> List[Node]: + try: + response = await executor.aresult(self._connection.get(path="/nodes")) + except Exception: + return [] + + response_typed = _decode_json_response_dict(response, "Nodes status") + assert response_typed is not None + nodes = response_typed.get("nodes") + if nodes is None or nodes == []: + raise EmptyResponseException("Nodes status response returned empty") + return cast(List[Node], nodes) + + async def get_number_of_nodes(self) -> int: + return len(await self.get_nodes_status()) diff --git a/weaviate/collections/batch/batch_wrapper.py b/weaviate/collections/batch/batch_wrapper.py index a64f267ca..3c1acc827 100644 --- a/weaviate/collections/batch/batch_wrapper.py +++ b/weaviate/collections/batch/batch_wrapper.py @@ -1,14 +1,18 @@ +import asyncio import time from typing import Any, Generic, List, Optional, Protocol, TypeVar, Union, cast +from weaviate.collections.batch.async_ import _BatchBaseAsync from weaviate.collections.batch.base import ( _BatchBase, - _BatchBaseNew, _BatchDataWrapper, _BatchMode, _ClusterBatch, + _ClusterBatchAsync, _DynamicBatching, + _ServerSideBatching, ) +from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.batch import ( BatchResult, ErrorObject, @@ -20,7 +24,7 @@ from weaviate.collections.classes.tenants import Tenant from weaviate.collections.classes.types import Properties, WeaviateProperties from weaviate.connect import executor -from weaviate.connect.v4 import ConnectionSync +from weaviate.connect.v4 import ConnectionAsync, 
ConnectionSync
 from weaviate.logger import logger
 from weaviate.types import UUID, VECTORS
 from weaviate.util import _capitalize_first_letter, _decode_json_response_list
@@ -34,7 +38,7 @@ def __init__(
     ):
         self._connection = connection
         self._consistency_level = consistency_level
-        self._current_batch: Optional[Union[_BatchBase, _BatchBaseNew]] = None
+        self._current_batch: Optional[Union[_BatchBase, _BatchBaseSync]] = None
         # config options
         self._batch_mode: _BatchMode = _DynamicBatching()
 
@@ -127,6 +131,109 @@ def results(self) -> BatchResult:
         return self._batch_data.results
 
 
+class _BatchWrapperAsync:
+    def __init__(
+        self,
+        connection: ConnectionAsync,
+        consistency_level: Optional[ConsistencyLevel],
+    ):
+        self._connection = connection
+        self._consistency_level = consistency_level
+        self._current_batch: Optional[_BatchBaseAsync] = None
+        # config options
+        self._batch_mode: _BatchMode = _ServerSideBatching(1)
+
+        self._batch_data = _BatchDataWrapper()
+        self._cluster = _ClusterBatchAsync(connection)
+
+    async def __is_ready(
+        self, max_count: int, shards: Optional[List[Shard]], backoff_count: int = 0
+    ) -> bool:
+        try:
+            readinesses = await asyncio.gather(
+                *[
+                    self.__get_shards_readiness(shard)
+                    for shard in shards or self._batch_data.imported_shards
+                ]
+            )
+            return all(all(readiness) for readiness in readinesses)
+        except Exception as e:
+            logger.warning(
+                f"Error while getting class shards statuses: {e}, trying again with 2**n={2**backoff_count}s exponential backoff with n={backoff_count}"
+            )
+            if backoff_count >= max_count:
+                raise e
+            await asyncio.sleep(2**backoff_count)
+            return await self.__is_ready(max_count, shards, backoff_count + 1)
+
+    async def wait_for_vector_indexing(
+        self, shards: Optional[List[Shard]] = None, how_many_failures: int = 5
+    ) -> None:
+        """Wait for all the vectors of the batch-imported objects to be indexed.
+
+        Upon network error, it will retry to get the shards' status for `how_many_failures` times
+        with exponential backoff (2**n seconds with n=0,1,2,...,how_many_failures).
+
+        Args:
+            shards: The shards to check the status of. If `None` it will check the status of all the shards of the imported objects in the batch.
+            how_many_failures: How many times to try to get the shards' status before raising an exception. Default 5.
+        """
+        if shards is not None and not isinstance(shards, list):
+            raise TypeError(f"'shards' must be of type List[Shard]. Given type: {type(shards)}.")
+        if shards is not None and not isinstance(shards[0], Shard):
+            raise TypeError(f"'shards' must be of type List[Shard]. Given type: {type(shards)}.")
+
+        waiting_count = 0
+        while not await self.__is_ready(how_many_failures, shards):
+            if waiting_count % 20 == 0:  # print every 5s
+                logger.debug("Waiting for async indexing to finish...")
+            await asyncio.sleep(0.25)
+            waiting_count += 1
+        logger.debug("Async indexing finished!")
+
+    async def __get_shards_readiness(self, shard: Shard) -> List[bool]:
+        path = f"/schema/{_capitalize_first_letter(shard.collection)}/shards{'' if shard.tenant is None else f'?tenant={shard.tenant}'}"
+        response = await executor.aresult(self._connection.get(path=path))
+
+        res = _decode_json_response_list(response, "Get shards' status")
+        assert res is not None
+        return [
+            (cast(str, shard.get("status")) == "READY")
+            & (cast(int, shard.get("vectorQueueSize")) == 0)
+            for shard in res
+        ]
+
+    async def _get_shards_readiness(self, shard: Shard) -> List[bool]:
+        return await self.__get_shards_readiness(shard)
+
+    @property
+    def failed_objects(self) -> List[ErrorObject]:
+        """Get all failed objects from the batch manager.
+
+        Returns:
+            A list of all the failed objects from the batch.
+        """
+        return self._batch_data.failed_objects
+
+    @property
+    def failed_references(self) -> List[ErrorReference]:
+        """Get all failed references from the batch manager.
+
+        Returns:
+            A list of all the failed references from the batch.
+        """
+        return self._batch_data.failed_references
+
+    @property
+    def results(self) -> BatchResult:
+        """Get the results of the batch operation.
+
+        Returns:
+            The results of the batch operation.
+        """
+        return self._batch_data.results
+
+
 class BatchClientProtocol(Protocol):
     def add_object(
         self,
@@ -204,6 +311,83 @@ def number_errors(self) -> int:
         ...
 
 
+class BatchClientProtocolAsync(Protocol):
+    async def add_object(
+        self,
+        collection: str,
+        properties: Optional[WeaviateProperties] = None,
+        references: Optional[ReferenceInputs] = None,
+        uuid: Optional[UUID] = None,
+        vector: Optional[VECTORS] = None,
+        tenant: Optional[Union[str, Tenant]] = None,
+    ) -> UUID:
+        """Add one object to this batch.
+
+        NOTE: If the UUID of one of the objects already exists then the existing object will be
+        replaced by the new object.
+
+        Args:
+            collection: The name of the collection this object belongs to.
+            properties: The data properties of the object to be added as a dictionary.
+            references: The references of the object to be added as a dictionary.
+            uuid: The UUID of the object as an uuid.UUID object or str. It can be a Weaviate beacon or Weaviate href.
+                If it is None, a UUIDv4 will be generated, by default None
+            vector: The embedding of the object. Can be used when a collection does not have a vectorization module or the given
+                vector was generated using the _identical_ vectorization module that is configured for the class. In this
+                case this vector takes precedence.
+                Supported types are:
+                - for single vectors: `list`, `numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, by default None.
+                - for named vectors: Dict[str, *list above*], where the string is the name of the vector.
+            tenant: The tenant name or Tenant object to be used for this request.
+
+        Returns:
+            The UUID of the added object. If one was not provided a UUIDv4 will be auto-generated for you and returned here.
+
+        Raises:
+            WeaviateBatchValidationError: If the provided options are not in the format required by Weaviate.
+        """
+        ...
+
+    async def add_reference(
+        self,
+        from_uuid: UUID,
+        from_collection: str,
+        from_property: str,
+        to: ReferenceInput,
+        tenant: Optional[Union[str, Tenant]] = None,
+    ) -> None:
+        """Add one reference to this batch.
+
+        Args:
+            from_uuid: The UUID of the object, as an uuid.UUID object or str, that should reference another object.
+            from_collection: The name of the collection that should reference another object.
+            from_property: The name of the property that contains the reference.
+            to: The UUID of the referenced object, as an uuid.UUID object or str, that is actually referenced.
+                For multi-target references use wvc.Reference.to_multi_target().
+            tenant: The tenant name or Tenant object to be used for this request.
+
+        Raises:
+            WeaviateBatchValidationError: If the provided options are not in the format required by Weaviate.
+        """
+        ...
+
+    async def flush(self) -> None:
+        """Flush the current batch.
+
+        This will send all the objects and references in the current batch to Weaviate.
+        """
+        ...
+
+    @property
+    def number_errors(self) -> int:
+        """Get the number of errors in the current batch.
+
+        Returns:
+            The number of errors in the current batch.
+        """
+        ...
+
+
 class BatchCollectionProtocol(Generic[Properties], Protocol[Properties]):
     def add_object(
         self,
@@ -260,8 +444,65 @@ def number_errors(self) -> int:
         ...
 
 
-T = TypeVar("T", bound=Union[_BatchBase, _BatchBaseNew])
+class BatchCollectionProtocolAsync(Generic[Properties], Protocol[Properties]):
+    async def add_object(
+        self,
+        properties: Optional[Properties] = None,
+        references: Optional[ReferenceInputs] = None,
+        uuid: Optional[UUID] = None,
+        vector: Optional[VECTORS] = None,
+    ) -> UUID:
+        """Add one object to this batch.
+
+        NOTE: If the UUID of one of the objects already exists then the existing object will be replaced by the new object.
+
+        Args:
+            properties: The data properties of the object to be added as a dictionary.
+            references: The references of the object to be added as a dictionary.
+            uuid: The UUID of the object as an uuid.UUID object or str. If it is None, a UUIDv4 will be generated, by default None
+            vector: The embedding of the object. Can be used when a collection does not have a vectorization module or the given
+                vector was generated using the _identical_ vectorization module that is configured for the class. In this
+                case this vector takes precedence. Supported types are:
+                - for single vectors: `list`, `numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, by default None.
+                - for named vectors: Dict[str, *list above*], where the string is the name of the vector.
+
+        Returns:
+            The UUID of the added object. If one was not provided a UUIDv4 will be auto-generated for you and returned here.
+
+        Raises:
+            WeaviateBatchValidationError: If the provided options are not in the format required by Weaviate.
+        """
+        ...
+
+    async def add_reference(
+        self, from_uuid: UUID, from_property: str, to: Union[ReferenceInput, List[UUID]]
+    ) -> None:
+        """Add a reference to this batch.
+
+        Args:
+            from_uuid: The UUID of the object, as an uuid.UUID object or str, that should reference another object.
+            from_property: The name of the property that contains the reference.
+            to: The UUID of the referenced object, as an uuid.UUID object or str, that is actually referenced.
+                For multi-target references use wvc.Reference.to_multi_target().
+
+        Raises:
+            WeaviateBatchValidationError: If the provided options are not in the format required by Weaviate.
+        """
+        ... 
+ + @property + def number_errors(self) -> int: + """Get the number of errors in the current batch. + + Returns: + The number of errors in the current batch. + """ + ... + + +T = TypeVar("T", bound=Union[_BatchBase, _BatchBaseSync]) P = TypeVar("P", bound=Union[BatchClientProtocol, BatchCollectionProtocol[Properties]]) +Q = TypeVar("Q", bound=Union[BatchClientProtocolAsync, BatchCollectionProtocolAsync[Properties]]) class _ContextManagerWrapper(Generic[T, P]): @@ -274,3 +515,16 @@ def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: def __enter__(self) -> P: self.__current_batch._start() return self.__current_batch # pyright: ignore[reportReturnType] + + +class _ContextManagerWrapperAsync(Generic[Q]): + def __init__(self, current_batch: _BatchBaseAsync): + self.__current_batch = current_batch + + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + await self.__current_batch._shutdown() + await self.__current_batch._wait() + + async def __aenter__(self) -> Q: + await self.__current_batch._start() + return self.__current_batch # pyright: ignore[reportReturnType] diff --git a/weaviate/collections/batch/client.py b/weaviate/collections/batch/client.py index a86a3be10..ca4126d8d 100644 --- a/weaviate/collections/batch/client.py +++ b/weaviate/collections/batch/client.py @@ -1,9 +1,9 @@ from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Optional, Type, Union +from weaviate.collections.batch.async_ import _BatchBaseAsync from weaviate.collections.batch.base import ( _BatchBase, - _BatchBaseNew, _BatchDataWrapper, _DynamicBatching, _FixedSizeBatching, @@ -12,15 +12,19 @@ ) from weaviate.collections.batch.batch_wrapper import ( BatchClientProtocol, + BatchClientProtocolAsync, _BatchMode, _BatchWrapper, + _BatchWrapperAsync, _ContextManagerWrapper, + _ContextManagerWrapperAsync, ) +from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers from weaviate.collections.classes.internal import ReferenceInput, ReferenceInputs from weaviate.collections.classes.tenants import Tenant from weaviate.collections.classes.types import WeaviateProperties -from weaviate.connect.v4 import ConnectionSync +from weaviate.connect.v4 import ConnectionAsync, ConnectionSync from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateUnsupportedFeatureError from weaviate.types import UUID, VECTORS @@ -38,31 +42,6 @@ def add_object( vector: Optional[VECTORS] = None, tenant: Optional[Union[str, Tenant]] = None, ) -> UUID: - """Add one object to this batch. - - NOTE: If the UUID of one of the objects already exists then the existing object will be - replaced by the new object. - - Args: - collection: The name of the collection this object belongs to. - properties: The data properties of the object to be added as a dictionary. - references: The references of the object to be added as a dictionary. - uuid: The UUID of the object as an uuid.UUID object or str. It can be a Weaviate beacon or Weaviate href. - If it is None an UUIDv4 will generated, by default None - vector: The embedding of the object. Can be used when a collection does not have a vectorization module or the given - vector was generated using the _identical_ vectorization module that is configured for the class. In this - case this vector takes precedence. - Supported types are: - - for single vectors: `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, by default None. 
- - for named vectors: Dict[str, *list above*], where the string is the name of the vector. - tenant: The tenant name or Tenant object to be used for this request. - - Returns: - The UUID of the added object. If one was not provided a UUIDv4 will be auto-generated for you and returned here. - - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ return super()._add_object( collection=collection, properties=properties, @@ -80,19 +59,6 @@ def add_reference( to: ReferenceInput, tenant: Optional[Union[str, Tenant]] = None, ) -> None: - """Add one reference to this batch. - - Args: - from_uuid: The UUID of the object, as an uuid.UUID object or str, that should reference another object. - from_collection: The name of the collection that should reference another object. - from_property: The name of the property that contains the reference. - to: The UUID of the referenced object, as an uuid.UUID object or str, that is actually referenced. - For multi-target references use wvc.Reference.to_multi_target(). - tenant: The tenant name or Tenant object to be used for this request. - - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ super()._add_reference( from_object_uuid=from_uuid, from_object_collection=from_collection, @@ -102,7 +68,7 @@ def add_reference( ) -class _BatchClientNew(_BatchBaseNew): +class _BatchClientSync(_BatchBaseSync): def add_object( self, collection: str, @@ -112,31 +78,6 @@ def add_object( vector: Optional[VECTORS] = None, tenant: Optional[Union[str, Tenant]] = None, ) -> UUID: - """Add one object to this batch. - - NOTE: If the UUID of one of the objects already exists then the existing object will be - replaced by the new object. - - Args: - collection: The name of the collection this object belongs to. - properties: The data properties of the object to be added as a dictionary. - references: The references of the object to be added as a dictionary. - uuid: The UUID of the object as an uuid.UUID object or str. It can be a Weaviate beacon or Weaviate href. - If it is None an UUIDv4 will generated, by default None - vector: The embedding of the object. Can be used when a collection does not have a vectorization module or the given - vector was generated using the _identical_ vectorization module that is configured for the class. In this - case this vector takes precedence. - Supported types are: - - for single vectors: `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, by default None. - - for named vectors: Dict[str, *list above*], where the string is the name of the vector. - tenant: The tenant name or Tenant object to be used for this request. - - Returns: - The UUID of the added object. If one was not provided a UUIDv4 will be auto-generated for you and returned here. - - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ return super()._add_object( collection=collection, properties=properties, @@ -154,19 +95,6 @@ def add_reference( to: ReferenceInput, tenant: Optional[Union[str, Tenant]] = None, ) -> None: - """Add one reference to this batch. - - Args: - from_uuid: The UUID of the object, as an uuid.UUID object or str, that should reference another object. - from_collection: The name of the collection that should reference another object. - from_property: The name of the property that contains the reference. 
- to: The UUID of the referenced object, as an uuid.UUID object or str, that is actually referenced. - For multi-target references use wvc.Reference.to_multi_target(). - tenant: The tenant name or Tenant object to be used for this request. - - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ super()._add_reference( from_object_uuid=from_uuid, from_object_collection=from_collection, @@ -176,11 +104,49 @@ def add_reference( ) +class _BatchClientAsync(_BatchBaseAsync): + async def add_object( + self, + collection: str, + properties: Optional[WeaviateProperties] = None, + references: Optional[ReferenceInputs] = None, + uuid: Optional[UUID] = None, + vector: Optional[VECTORS] = None, + tenant: Optional[Union[str, Tenant]] = None, + ) -> UUID: + return await super()._add_object( + collection=collection, + properties=properties, + references=references, + uuid=uuid, + vector=vector, + tenant=tenant.name if isinstance(tenant, Tenant) else tenant, + ) + + async def add_reference( + self, + from_uuid: UUID, + from_collection: str, + from_property: str, + to: ReferenceInput, + tenant: Optional[Union[str, Tenant]] = None, + ) -> None: + await super()._add_reference( + from_object_uuid=from_uuid, + from_object_collection=from_collection, + from_property_name=from_property, + to=to, + tenant=tenant.name if isinstance(tenant, Tenant) else tenant, + ) + + BatchClient = _BatchClient -BatchClientNew = _BatchClientNew +BatchClientSync = _BatchClientSync +BatchClientAsync = _BatchClientAsync ClientBatchingContextManager = _ContextManagerWrapper[ - Union[BatchClient, BatchClientNew], BatchClientProtocol + Union[BatchClient, BatchClientSync], BatchClientProtocol ] +AsyncClientBatchingContextManager = _ContextManagerWrapperAsync[BatchClientProtocolAsync] class _BatchClientWrapper(_BatchWrapper): @@ -197,7 +163,7 @@ def __init__( # define one executor per client with it shared between all child batch contexts def __create_batch_and_reset( - self, batch_client: Union[Type[_BatchClient], Type[_BatchClientNew]] + self, batch_client: Union[Type[_BatchClient], Type[_BatchClientSync]] ): if self._vectorizer_batching is None or not self._vectorizer_batching: try: @@ -311,4 +277,46 @@ def experimental( concurrency=1, # hard-code until client-side multi-threading is fixed ) self._consistency_level = consistency_level - return self.__create_batch_and_reset(_BatchClientNew) + return self.__create_batch_and_reset(_BatchClientSync) + + +class _BatchClientWrapperAsync(_BatchWrapperAsync): + def __init__( + self, + connection: ConnectionAsync, + ): + super().__init__(connection, None) + self._vectorizer_batching: Optional[bool] = None + + def __create_batch_and_reset(self): + self._batch_data = _BatchDataWrapper() # clear old data + return _ContextManagerWrapperAsync( + BatchClientAsync( + connection=self._connection, + consistency_level=self._consistency_level, + results=self._batch_data, + ) + ) + + def experimental( + self, + *, + concurrency: Optional[int] = None, + consistency_level: Optional[ConsistencyLevel] = None, + ) -> AsyncClientBatchingContextManager: + """Configure the batching context manager using the experimental server-side batching mode. + + When you exit the context manager, the final batch will be sent automatically. 
+ """ + if self._connection._weaviate_version.is_lower_than(1, 34, 0): + raise WeaviateUnsupportedFeatureError( + "Server-side batching", str(self._connection._weaviate_version), "1.34.0" + ) + self._batch_mode = _ServerSideBatching( + # concurrency=concurrency + # if concurrency is not None + # else len(self._cluster.get_nodes_status()) + concurrency=1, # hard-code until client-side multi-threading is fixed + ) + self._consistency_level = consistency_level + return self.__create_batch_and_reset() diff --git a/weaviate/collections/batch/collection.py b/weaviate/collections/batch/collection.py index 6abe4aaac..e06531680 100644 --- a/weaviate/collections/batch/collection.py +++ b/weaviate/collections/batch/collection.py @@ -1,9 +1,9 @@ from concurrent.futures import ThreadPoolExecutor from typing import TYPE_CHECKING, Generic, List, Optional, Type, Union +from weaviate.collections.batch.async_ import _BatchBaseAsync from weaviate.collections.batch.base import ( _BatchBase, - _BatchBaseNew, _BatchDataWrapper, _BatchMode, _DynamicBatching, @@ -13,13 +13,17 @@ ) from weaviate.collections.batch.batch_wrapper import ( BatchCollectionProtocol, + BatchCollectionProtocolAsync, _BatchWrapper, + _BatchWrapperAsync, _ContextManagerWrapper, + _ContextManagerWrapperAsync, ) +from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers from weaviate.collections.classes.internal import ReferenceInput, ReferenceInputs from weaviate.collections.classes.types import Properties -from weaviate.connect.v4 import ConnectionSync +from weaviate.connect.v4 import ConnectionAsync, ConnectionSync from weaviate.exceptions import UnexpectedStatusCodeError, WeaviateUnsupportedFeatureError from weaviate.types import UUID, VECTORS @@ -78,7 +82,7 @@ def add_reference( ) -class _BatchCollectionNew(Generic[Properties], _BatchBaseNew): +class _BatchCollectionSync(Generic[Properties], _BatchBaseSync): def __init__( self, executor: ThreadPoolExecutor, @@ -108,26 +112,6 @@ def add_object( uuid: Optional[UUID] = None, vector: Optional[VECTORS] = None, ) -> UUID: - """Add one object to this batch. - - NOTE: If the UUID of one of the objects already exists then the existing object will be replaced by the new object. - - Args: - properties: The data properties of the object to be added as a dictionary. - references: The references of the object to be added as a dictionary. - uuid: The UUID of the object as an uuid.UUID object or str. If it is None an UUIDv4 will generated, by default None - vector: The embedding of the object. Can be used when a collection does not have a vectorization module or the given - vector was generated using the _identical_ vectorization module that is configured for the class. In this - case this vector takes precedence. Supported types are: - - for single vectors: `list`, 'numpy.ndarray`, `torch.Tensor` and `tf.Tensor`, by default None. - - for named vectors: Dict[str, *list above*], where the string is the name of the vector. - - Returns: - The UUID of the added object. If one was not provided a UUIDv4 will be auto-generated for you and returned here. - - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ return self._add_object( collection=self.__name, properties=properties, @@ -140,18 +124,52 @@ def add_object( def add_reference( self, from_uuid: UUID, from_property: str, to: Union[ReferenceInput, List[UUID]] ) -> None: - """Add a reference to this batch. 
+ self._add_reference( + from_uuid, + self.__name, + from_property, + to, + self.__tenant, + ) - Args: - from_uuid: The UUID of the object, as an uuid.UUID object or str, that should reference another object. - from_property: The name of the property that contains the reference. - to: The UUID of the referenced object, as an uuid.UUID object or str, that is actually referenced. - For multi-target references use wvc.Reference.to_multi_target(). - Raises: - WeaviateBatchValidationError: If the provided options are in the format required by Weaviate. - """ - self._add_reference( +class _BatchCollectionAsync(Generic[Properties], _BatchBaseAsync): + def __init__( + self, + connection: ConnectionAsync, + consistency_level: Optional[ConsistencyLevel], + results: _BatchDataWrapper, + name: str, + tenant: Optional[str], + ) -> None: + super().__init__( + connection=connection, + consistency_level=consistency_level, + results=results, + ) + self.__name = name + self.__tenant = tenant + + async def add_object( + self, + properties: Optional[Properties] = None, + references: Optional[ReferenceInputs] = None, + uuid: Optional[UUID] = None, + vector: Optional[VECTORS] = None, + ) -> UUID: + return await self._add_object( + collection=self.__name, + properties=properties, + references=references, + uuid=uuid, + vector=vector, + tenant=self.__tenant, + ) + + async def add_reference( + self, from_uuid: UUID, from_property: str, to: Union[ReferenceInput, List[UUID]] + ) -> None: + await self._add_reference( from_uuid, self.__name, from_property, @@ -161,11 +179,15 @@ def add_reference( BatchCollection = _BatchCollection -BatchCollectionNew = _BatchCollectionNew +BatchCollectionSync = _BatchCollectionSync +BatchCollectionAsync = _BatchCollectionAsync CollectionBatchingContextManager = _ContextManagerWrapper[ - Union[BatchCollection[Properties], BatchCollectionNew[Properties]], + Union[BatchCollection[Properties], BatchCollectionSync[Properties]], BatchCollectionProtocol[Properties], ] +CollectionBatchingContextManagerAsync = _ContextManagerWrapperAsync[ + BatchCollectionProtocolAsync[Properties] +] class _BatchCollectionWrapper(Generic[Properties], _BatchWrapper): @@ -177,7 +199,7 @@ def __init__( tenant: Optional[str], config: "_ConfigCollection", batch_client: Union[ - Type[_BatchCollection[Properties]], Type[_BatchCollectionNew[Properties]] + Type[_BatchCollection[Properties]], Type[_BatchCollectionSync[Properties]] ], ) -> None: super().__init__(connection, consistency_level) @@ -192,7 +214,7 @@ def __init__( def __create_batch_and_reset( self, batch_client: Union[ - Type[_BatchCollection[Properties]], Type[_BatchCollectionNew[Properties]] + Type[_BatchCollection[Properties]], Type[_BatchCollectionSync[Properties]] ], ): if self._vectorizer_batching is None: @@ -278,4 +300,48 @@ def experimental( # else len(self._cluster.get_nodes_status()) concurrency=1, # hard-code until client-side multi-threading is fixed ) - return self.__create_batch_and_reset(_BatchCollectionNew) + return self.__create_batch_and_reset(_BatchCollectionSync) + + +class _BatchCollectionWrapperAsync(Generic[Properties], _BatchWrapperAsync): + def __init__( + self, + connection: ConnectionAsync, + consistency_level: Optional[ConsistencyLevel], + name: str, + tenant: Optional[str], + ) -> None: + super().__init__(connection, consistency_level) + self.__name = name + self.__tenant = tenant + + def __create_batch_and_reset(self): + self._batch_data = _BatchDataWrapper() # clear old data + return _ContextManagerWrapperAsync( + 
BatchCollectionAsync( + connection=self._connection, + consistency_level=self._consistency_level, + results=self._batch_data, + name=self.__name, + tenant=self.__tenant, + ) + ) + + def experimental( + self, + ) -> CollectionBatchingContextManagerAsync[Properties]: + """Configure the batching context manager using the experimental server-side batching mode. + + When you exit the context manager, the final batch will be sent automatically. + """ + if self._connection._weaviate_version.is_lower_than(1, 34, 0): + raise WeaviateUnsupportedFeatureError( + "Server-side batching", str(self._connection._weaviate_version), "1.34.0" + ) + self._batch_mode = _ServerSideBatching( + # concurrency=concurrency + # if concurrency is not None + # else len(self._cluster.get_nodes_status()) + concurrency=1, # hard-code until client-side multi-threading is fixed + ) + return self.__create_batch_and_reset() diff --git a/weaviate/collections/batch/grpc_batch.py b/weaviate/collections/batch/grpc_batch.py index 7384dcb49..6f01b2287 100644 --- a/weaviate/collections/batch/grpc_batch.py +++ b/weaviate/collections/batch/grpc_batch.py @@ -20,7 +20,7 @@ from weaviate.collections.grpc.shared import _BaseGRPC, _is_1d_vector, _Pack from weaviate.connect import executor from weaviate.connect.base import MAX_GRPC_MESSAGE_LENGTH -from weaviate.connect.v4 import Connection, ConnectionSync +from weaviate.connect.v4 import Connection, ConnectionAsync, ConnectionSync from weaviate.exceptions import ( WeaviateInsertInvalidPropertyError, WeaviateInsertManyAllFailedError, @@ -203,8 +203,8 @@ def stream( connection: ConnectionSync, *, requests: Generator[batch_pb2.BatchStreamRequest, None, None], - ) -> Generator[batch_pb2.BatchStreamReply, None, None]: - """Start a new stream for receiving messages about the ongoing server-side batching from Weaviate. + ): + """Start a new sync stream for send/recv messages about the ongoing server-side batching from Weaviate. Args: connection: The connection to the Weaviate instance. @@ -212,6 +212,17 @@ def stream( """ return connection.grpc_batch_stream(requests=requests) + def astream( + self, + connection: ConnectionAsync, + ): + """Start a new async stream for send/recv messages about the ongoing server-side batching from Weaviate. + + Args: + connection: The connection to the Weaviate instance. 
+        """
+        return connection.grpc_batch_stream()
+
     def __translate_properties_from_python_to_grpc(
         self, data: Dict[str, Any], refs: ReferenceInputs
     ) -> batch_pb2.BatchObject.Properties:
diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py
new file mode 100644
index 000000000..54322ee9d
--- /dev/null
+++ b/weaviate/collections/batch/sync.py
@@ -0,0 +1,602 @@
+import threading
+import time
+import uuid as uuid_package
+from concurrent.futures import ThreadPoolExecutor
+from queue import Queue
+from typing import Generator, List, Optional, Set, Union
+
+from pydantic import ValidationError
+
+from weaviate.collections.batch.base import (
+    GCP_STREAM_TIMEOUT,
+    ObjectsBatchRequest,
+    ReferencesBatchRequest,
+    _BatchDataWrapper,
+    _BatchMode,
+    _BgThreads,
+    _ClusterBatch,
+    _ServerSideBatching,
+)
+from weaviate.collections.batch.grpc_batch import _BatchGRPC
+from weaviate.collections.classes.batch import (
+    BatchObject,
+    BatchObjectReturn,
+    BatchReference,
+    BatchReferenceReturn,
+    ErrorObject,
+    ErrorReference,
+    Shard,
+)
+from weaviate.collections.classes.config import ConsistencyLevel
+from weaviate.collections.classes.internal import (
+    ReferenceInput,
+    ReferenceInputs,
+    ReferenceToMulti,
+)
+from weaviate.collections.classes.types import WeaviateProperties
+from weaviate.connect.v4 import ConnectionSync
+from weaviate.exceptions import (
+    WeaviateBatchStreamError,
+    WeaviateBatchValidationError,
+    WeaviateGRPCUnavailableError,
+    WeaviateStartUpError,
+)
+from weaviate.logger import logger
+from weaviate.proto.v1 import batch_pb2
+from weaviate.types import UUID, VECTORS
+
+
+class _BatchBaseSync:
+    def __init__(
+        self,
+        connection: ConnectionSync,
+        consistency_level: Optional[ConsistencyLevel],
+        results: _BatchDataWrapper,
+        batch_mode: _BatchMode,
+        executor: ThreadPoolExecutor,
+        vectorizer_batching: bool,
+        objects: Optional[ObjectsBatchRequest[BatchObject]] = None,
+        references: Optional[ReferencesBatchRequest[BatchReference]] = None,
+    ) -> None:
+        self.__batch_objects = objects or ObjectsBatchRequest[BatchObject]()
+        self.__batch_references = references or ReferencesBatchRequest[BatchReference]()
+
+        self.__connection = connection
+        self.__is_gcp_on_wcd = connection._connection_params.is_gcp_on_wcd()
+        self.__stream_start: Optional[float] = None
+        self.__is_renewing_stream = threading.Event()
+        self.__consistency_level: ConsistencyLevel = consistency_level or ConsistencyLevel.QUORUM
+        self.__batch_size = 100
+
+        self.__batch_grpc = _BatchGRPC(
+            connection._weaviate_version, self.__consistency_level, connection._grpc_max_msg_size
+        )
+        self.__cluster = _ClusterBatch(self.__connection)
+        self.__number_of_nodes = self.__cluster.get_number_of_nodes()
+
+        # lookup table for objects that are currently being processed - used to avoid sending references from objects that have not been added yet
+        self.__uuid_lookup: Set[str] = set()
+
+        # users must not access the results directly while the batch is running, as they are not thread-safe
+        self.__results_for_wrapper_backup = results
+        self.__results_for_wrapper = _BatchDataWrapper()
+
+        self.__objs_count = 0
+        self.__refs_count = 0
+
+        self.__uuid_lookup_lock = threading.Lock()
+        self.__results_lock = threading.Lock()
+
+        self.__bg_exception: Optional[Exception] = None
+        self.__is_oom = threading.Event()
+        self.__is_shutting_down = threading.Event()
+        self.__is_shutdown = threading.Event()
+
+        self.__objs_cache_lock = threading.Lock()
+        self.__refs_cache_lock = threading.Lock()
+        self.__objs_cache: dict[str, 
BatchObject] = {} + self.__refs_cache: dict[str, BatchReference] = {} + + self.__acks_lock = threading.Lock() + self.__inflight_objs: set[str] = set() + self.__inflight_refs: set[str] = set() + + # maxsize=1 so that __send does not run faster than generator for __recv + # thereby using too much buffer in case of server-side shutdown + self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1) + + self.__stop = False + + self.__batch_mode = batch_mode + + @property + def number_errors(self) -> int: + """Return the number of errors in the batch.""" + return len(self.__results_for_wrapper.failed_objects) + len( + self.__results_for_wrapper.failed_references + ) + + def __all_threads_alive(self) -> bool: + return self.__bg_threads is not None and all( + thread.is_alive() for thread in self.__bg_threads + ) + + def __any_threads_alive(self) -> bool: + return self.__bg_threads is not None and any( + thread.is_alive() for thread in self.__bg_threads + ) + + def _start(self) -> None: + assert isinstance(self.__batch_mode, _ServerSideBatching), ( + "Only server-side batching is supported in this mode" + ) + self.__bg_threads = [ + self.__start_bg_threads() for _ in range(self.__batch_mode.concurrency) + ] + logger.warning( + f"Provisioned {len(self.__bg_threads)} stream(s) to the server for batch processing" + ) + now = time.time() + while not self.__all_threads_alive(): + # wait for the stream to be started by __batch_stream + time.sleep(0.01) + if time.time() - now > 60: + raise WeaviateBatchStreamError( + "Batch stream was not started within 60 seconds. Please check your connection." + ) + + def _wait(self) -> None: + for bg_thread in self.__bg_threads: + bg_thread.join() + + def _shutdown(self) -> None: + # Shutdown the current batch and wait for all requests to be finished + self.__stop = True + + # copy the results to the public results + self.__results_for_wrapper_backup.results = self.__results_for_wrapper.results + self.__results_for_wrapper_backup.failed_objects = self.__results_for_wrapper.failed_objects + self.__results_for_wrapper_backup.failed_references = ( + self.__results_for_wrapper.failed_references + ) + self.__results_for_wrapper_backup.imported_shards = ( + self.__results_for_wrapper.imported_shards + ) + + def __loop(self) -> None: + refresh_time: float = 0.01 + while self.__bg_exception is None: + if len(self.__batch_objects) + len(self.__batch_references) > 0: + self._batch_send = True + start = time.time() + while (len_o := len(self.__batch_objects)) + ( + len_r := len(self.__batch_references) + ) < self.__batch_size: + # wait for more objects to be added up to the batch size + time.sleep(0.01) + if time.time() - start >= 1 and ( + len_o == len(self.__batch_objects) or len_r == len(self.__batch_references) + ): + # no new objects were added in the last second, exit the loop + break + + objs = self.__batch_objects.pop_items(self.__batch_size) + refs = self.__batch_references.pop_items( + self.__batch_size - len(objs), + uuid_lookup=self.__uuid_lookup, + ) + with self.__uuid_lookup_lock: + self.__uuid_lookup.difference_update(obj.uuid for obj in objs) + + for req in self.__generate_stream_requests(objs, refs): + self.__reqs.put(req) + elif self.__stop: + # we are done, send the sentinel into our queue to be consumed by the batch sender + self.__reqs.put(None) # signal the end of the stream + logger.warning("Batching finished, sent stop signal to batch stream") + return + time.sleep(refresh_time) + + def __beacon(self, ref: batch_pb2.BatchReference) -> str: 
+        return f"weaviate://localhost/{ref.from_collection}{f'#{ref.tenant}' if ref.tenant != '' else ''}/{ref.from_uuid}#{ref.name}->/{ref.to_collection}/{ref.to_uuid}"
+
+    def __generate_stream_requests(
+        self,
+        objects: List[BatchObject],
+        references: List[BatchReference],
+    ) -> Generator[batch_pb2.BatchStreamRequest, None, None]:
+        per_object_overhead = 4  # extra overhead bytes per object in the request
+
+        def request_maker():
+            return batch_pb2.BatchStreamRequest()
+
+        request = request_maker()
+        total_size = request.ByteSize()
+
+        inflight_objs = set()
+        inflight_refs = set()
+        for object_ in objects:
+            obj = self.__batch_grpc.grpc_object(object_._to_internal())
+            obj_size = obj.ByteSize() + per_object_overhead
+
+            if total_size + obj_size >= self.__batch_grpc.grpc_max_msg_size:
+                yield request
+                request = request_maker()
+                total_size = request.ByteSize()
+
+            request.data.objects.values.append(obj)
+            total_size += obj_size
+            inflight_objs.add(obj.uuid)
+
+        for reference in references:
+            ref = self.__batch_grpc.grpc_reference(reference._to_internal())
+            ref_size = ref.ByteSize() + per_object_overhead
+
+            if total_size + ref_size >= self.__batch_grpc.grpc_max_msg_size:
+                yield request
+                request = request_maker()
+                total_size = request.ByteSize()
+
+            request.data.references.values.append(ref)
+            total_size += ref_size
+            inflight_refs.add(reference._to_beacon())
+
+        with self.__acks_lock:
+            self.__inflight_objs.update(inflight_objs)
+            self.__inflight_refs.update(inflight_refs)
+
+        if len(request.data.objects.values) > 0 or len(request.data.references.values) > 0:
+            yield request
+
+    def __send(
+        self,
+    ) -> Generator[batch_pb2.BatchStreamRequest, None, None]:
+        yield batch_pb2.BatchStreamRequest(
+            start=batch_pb2.BatchStreamRequest.Start(
+                consistency_level=self.__batch_grpc._consistency_level,
+            ),
+        )
+        while self.__bg_exception is None:
+            if self.__is_gcp_on_wcd:
+                assert self.__stream_start is not None
+                if time.time() - self.__stream_start > GCP_STREAM_TIMEOUT:
+                    logger.warning(
+                        "GCP connections have a maximum lifetime. Re-establishing the batch stream to avoid timeout errors."
+ ) + yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) + self.__is_renewing_stream.set() + return + req = self.__reqs.get() + if req is not None: + yield req + continue + if self.__stop and not ( + self.__is_shutting_down.is_set() or self.__is_shutdown.is_set() + ): + logger.warning("Batching finished, closing the client-side of the stream") + yield batch_pb2.BatchStreamRequest(stop=batch_pb2.BatchStreamRequest.Stop()) + return + if self.__is_shutting_down.is_set(): + logger.warning("Server shutting down, closing the client-side of the stream") + return + if self.__is_oom.is_set(): + logger.warning("Server out-of-memory, closing the client-side of the stream") + return + logger.warning("Received sentinel, but not stopping, continuing...") + + def __recv(self) -> None: + for message in self.__batch_grpc.stream( + connection=self.__connection, + requests=self.__send(), + ): + if message.HasField("started"): + logger.warning("Batch stream started successfully") + for threads in self.__bg_threads: + threads.start_loop() + if message.HasField("backoff"): + if ( + message.backoff.batch_size != self.__batch_size + and not self.__is_shutting_down.is_set() + and not self.__is_shutdown.is_set() + and not self.__stop + ): + self.__batch_size = message.backoff.batch_size + logger.warning( + f"Updated batch size to {self.__batch_size} as per server request" + ) + if message.HasField("acks"): + with self.__acks_lock: + self.__inflight_objs.difference_update(message.acks.uuids) + self.__uuid_lookup.difference_update(message.acks.uuids) + self.__inflight_refs.difference_update(message.acks.beacons) + if message.HasField("results"): + result_objs = BatchObjectReturn() + result_refs = BatchReferenceReturn() + failed_objs: List[ErrorObject] = [] + failed_refs: List[ErrorReference] = [] + for error in message.results.errors: + if error.HasField("uuid"): + try: + with self.__objs_cache_lock: + cached = self.__objs_cache.pop(error.uuid) + except KeyError: + continue + err = ErrorObject( + message=error.error, + object_=cached, + ) + result_objs += BatchObjectReturn( + _all_responses=[err], + errors={cached.index: err}, + ) + failed_objs.append(err) + logger.warning( + { + "error": error.error, + "object": error.uuid, + "action": "use {client,collection}.batch.failed_objects to access this error", + } + ) + if error.HasField("beacon"): + try: + with self.__refs_cache_lock: + cached = self.__refs_cache.pop(error.beacon) + except KeyError: + continue + err = ErrorReference( + message=error.error, + reference=cached, + ) + failed_refs.append(err) + result_refs += BatchReferenceReturn( + errors={cached.index: err}, + ) + logger.warning( + { + "error": error.error, + "reference": error.beacon, + "action": "use {client,collection}.batch.failed_references to access this error", + } + ) + for success in message.results.successes: + if success.HasField("uuid"): + try: + with self.__objs_cache_lock: + cached = self.__objs_cache.pop(success.uuid) + except KeyError: + continue + uuid = uuid_package.UUID(success.uuid) + result_objs += BatchObjectReturn( + _all_responses=[uuid], + uuids={cached.index: uuid}, + ) + if success.HasField("beacon"): + try: + with self.__refs_cache_lock: + self.__refs_cache.pop(success.beacon, None) + except KeyError: + continue + with self.__results_lock: + self.__results_for_wrapper.results.objs += result_objs + self.__results_for_wrapper.results.refs += result_refs + self.__results_for_wrapper.failed_objects.extend(failed_objs) + 
self.__results_for_wrapper.failed_references.extend(failed_refs) + if message.HasField("out_of_memory"): + logger.warning( + "Server reported out-of-memory error. Batching will wait at most 10 minutes for the server to scale-up. If the server does not recover within this time, the batch will terminate with an error." + ) + self.__is_oom.set() + with self.__objs_cache_lock: + self.__batch_objects.prepend( + [self.__objs_cache[uuid] for uuid in message.out_of_memory.uuids] + ) + with self.__refs_cache_lock: + self.__batch_references.prepend( + [self.__refs_cache[beacon] for beacon in message.out_of_memory.beacons] + ) + if message.HasField("shutting_down"): + logger.warning( + "Received shutting down message from server, pausing sending until stream is re-established" + ) + self.__is_shutting_down.set() + self.__is_oom.clear() + if message.HasField("shutdown"): + logger.warning("Received shutdown finished message from server") + self.__is_shutdown.set() + self.__is_shutting_down.clear() + self.__reconnect() + + # restart the stream if we were shutdown by the node we were connected to ensuring that the index is + # propagated properly from it to the new one + if self.__is_shutdown.is_set(): + logger.warning("Restarting batch recv after shutdown...") + self.__is_shutdown.clear() + return self.__recv() + elif self.__is_renewing_stream.is_set(): + # restart the stream if we are renewing it (GCP connections have a max lifetime) + logger.warning("Restarting batch recv after renewing stream...") + self.__is_renewing_stream.clear() + return self.__recv() + else: + logger.warning("Server closed the stream from its side, shutting down batch") + return + + def __reconnect(self, retry: int = 0) -> None: + if self.__consistency_level == ConsistencyLevel.ALL or self.__number_of_nodes == 1: + # check that all nodes are available before reconnecting + up_nodes = self.__cluster.get_nodes_status() + while len(up_nodes) != self.__number_of_nodes or any( + node["status"] != "HEALTHY" for node in up_nodes + ): + logger.warning( + "Waiting for all nodes to be HEALTHY before reconnecting to batch stream..." + ) + time.sleep(5) + up_nodes = self.__cluster.get_nodes_status() + try: + logger.warning(f"Trying to reconnect after shutdown... 
{retry + 1}/5")
+            self.__connection.close("sync")
+            self.__connection.connect(force=True)
+            logger.warning("Reconnected successfully")
+        except (WeaviateStartUpError, WeaviateGRPCUnavailableError) as e:
+            if retry < 5:
+                time.sleep(2**retry)
+                self.__reconnect(retry + 1)
+            else:
+                logger.error("Failed to reconnect after 5 attempts")
+                self.__bg_exception = e
+
+    def __start_bg_threads(self) -> _BgThreads:
+        """Create the background send/receive threads for this batch stream, starting the receiver immediately."""
+
+        def loop_wrapper() -> None:
+            try:
+                self.__loop()
+                logger.warning("exited batch requests loop thread")
+            except Exception as e:
+                logger.error(e)
+                self.__bg_exception = e
+
+        def recv_wrapper() -> None:
+            socket_hung_up = False
+            try:
+                self.__recv()
+                logger.warning("exited batch receive thread")
+            except Exception as e:
+                if isinstance(e, WeaviateBatchStreamError) and (
+                    "Socket closed" in e.message
+                    or "context canceled" in e.message
+                    or "Connection reset" in e.message
+                    or "Received RST_STREAM with error code 2" in e.message
+                ):
+                    logger.error(f"Socket hung up detected in batch receive thread: {e.message}")
+                    socket_hung_up = True
+                else:
+                    logger.error(e)
+                    logger.error(type(e))
+                    self.__bg_exception = e
+            if socket_hung_up:
+                # this happens during ungraceful shutdown of the coordinator
+                # let's restart the stream and add the cached objects again
+                logger.warning("Stream closed unexpectedly, restarting...")
+                self.__reconnect()
+                # server sets this whenever it restarts, gracefully or unexpectedly, so we need to clear it now
+                self.__is_shutting_down.clear()
+                with self.__objs_cache_lock:
+                    logger.warning(
+                        f"Re-adding {len(self.__objs_cache)} cached objects to the batch"
+                    )
+                    self.__batch_objects.prepend(list(self.__objs_cache.values()))
+                with self.__refs_cache_lock:
+                    self.__batch_references.prepend(list(self.__refs_cache.values()))
+                # start a new stream with a newly reconnected channel
+                return recv_wrapper()
+
+        threads = _BgThreads(
+            loop=threading.Thread(
+                target=loop_wrapper,
+                daemon=True,
+                name="BgBatchLoop",
+            ),
+            recv=threading.Thread(
+                target=recv_wrapper,
+                daemon=True,
+                name="BgBatchRecv",
+            ),
+        )
+        threads.start_recv()
+        return threads
+
+    def flush(self) -> None:
+        """Flush the batch queue and wait for all requests to be finished."""
+        # the bg thread sends objs+refs automatically, so simply wait for everything to be done
+        while len(self.__batch_objects) > 0 or len(self.__batch_references) > 0:
+            time.sleep(0.01)
+            self.__check_bg_threads_alive()
+
+    def _add_object(
+        self,
+        collection: str,
+        properties: Optional[WeaviateProperties] = None,
+        references: Optional[ReferenceInputs] = None,
+        uuid: Optional[UUID] = None,
+        vector: Optional[VECTORS] = None,
+        tenant: Optional[str] = None,
+    ) -> UUID:
+        self.__check_bg_threads_alive()
+        try:
+            batch_object = BatchObject(
+                collection=collection,
+                properties=properties,
+                references=references,
+                uuid=uuid,
+                vector=vector,
+                tenant=tenant,
+                index=self.__objs_count,
+            )
+            self.__results_for_wrapper.imported_shards.add(
+                Shard(collection=collection, tenant=tenant)
+            )
+        except ValidationError as e:
+            raise WeaviateBatchValidationError(repr(e))
+        uuid = str(batch_object.uuid)
+        with self.__uuid_lookup_lock:
+            self.__uuid_lookup.add(uuid)
+        self.__batch_objects.add(batch_object)
+        with self.__objs_cache_lock:
+            self.__objs_cache[uuid] = batch_object
+        self.__objs_count += 1
+
+        # block if queue gets too long or weaviate is overloaded - reading files is faster than sending them so we do
+        # not need a long 
queue + while len(self.__inflight_objs) >= self.__batch_size: + self.__check_bg_threads_alive() + time.sleep(0.01) + + assert batch_object.uuid is not None + return batch_object.uuid + + def _add_reference( + self, + from_object_uuid: UUID, + from_object_collection: str, + from_property_name: str, + to: ReferenceInput, + tenant: Optional[str] = None, + ) -> None: + self.__check_bg_threads_alive() + if isinstance(to, ReferenceToMulti): + to_strs: Union[List[str], List[UUID]] = to.uuids_str + elif isinstance(to, str) or isinstance(to, uuid_package.UUID): + to_strs = [to] + else: + to_strs = list(to) + + for uid in to_strs: + try: + batch_reference = BatchReference( + from_object_collection=from_object_collection, + from_object_uuid=from_object_uuid, + from_property_name=from_property_name, + to_object_collection=( + to.target_collection if isinstance(to, ReferenceToMulti) else None + ), + to_object_uuid=uid, + tenant=tenant, + index=self.__refs_count, + ) + except ValidationError as e: + raise WeaviateBatchValidationError(repr(e)) + self.__batch_references.add(batch_reference) + with self.__refs_cache_lock: + self.__refs_cache[batch_reference._to_beacon()] = batch_reference + self.__refs_count += 1 + while len(self.__inflight_refs) >= self.__batch_size * 2: + self.__check_bg_threads_alive() + time.sleep(0.01) + + def __check_bg_threads_alive(self) -> None: + if self.__all_threads_alive(): + return + + raise self.__bg_exception or Exception("Batch thread died unexpectedly") diff --git a/weaviate/collections/collection/async_.py b/weaviate/collections/collection/async_.py index 47ff6d2d1..2c0cea5b0 100644 --- a/weaviate/collections/collection/async_.py +++ b/weaviate/collections/collection/async_.py @@ -5,6 +5,9 @@ from weaviate.cluster import _ClusterAsync from weaviate.collections.aggregate import _AggregateCollectionAsync from weaviate.collections.backups import _CollectionBackupAsync +from weaviate.collections.batch.collection import ( + _BatchCollectionWrapperAsync, +) from weaviate.collections.classes.cluster import Shard from weaviate.collections.classes.config import ConsistencyLevel from weaviate.collections.classes.grpc import METADATA, PROPERTIES, REFERENCES @@ -77,6 +80,15 @@ def __init__( """This namespace includes all the querying methods available to you when using Weaviate's standard aggregation capabilities.""" self.backup: _CollectionBackupAsync = _CollectionBackupAsync(connection, name) """This namespace includes all the backup methods available to you when backing up a collection in Weaviate.""" + self.batch: _BatchCollectionWrapperAsync[Properties] = _BatchCollectionWrapperAsync[ + Properties + ]( + connection, + consistency_level, + name, + tenant, + ) + """This namespace contains all the functionality to upload data in batches to Weaviate for this specific collection.""" self.config = _ConfigCollectionAsync(connection, name, tenant) """This namespace includes all the CRUD methods available to you when modifying the configuration of the collection in Weaviate.""" self.data = _DataCollectionAsync[Properties]( diff --git a/weaviate/collections/collection/sync.py b/weaviate/collections/collection/sync.py index 88f728b30..d50c0c2b5 100644 --- a/weaviate/collections/collection/sync.py +++ b/weaviate/collections/collection/sync.py @@ -7,7 +7,7 @@ from weaviate.collections.backups import _CollectionBackup from weaviate.collections.batch.collection import ( _BatchCollection, - _BatchCollectionNew, + _BatchCollectionSync, _BatchCollectionWrapper, ) from 
weaviate.collections.classes.cluster import Shard @@ -101,10 +101,8 @@ def __init__( name, tenant, config, - batch_client=_BatchCollectionNew[Properties] - if connection._weaviate_version.is_at_least( - 1, 32, 0 - ) # todo: change to 1.33.0 when it lands + batch_client=_BatchCollectionSync[Properties] + if connection._weaviate_version.is_at_least(1, 36, 0) else _BatchCollection[Properties], ) """This namespace contains all the functionality to upload data in batches to Weaviate for this specific collection.""" diff --git a/weaviate/collections/data/async_.pyi b/weaviate/collections/data/async_.pyi index 28dd4e2e4..15108447a 100644 --- a/weaviate/collections/data/async_.pyi +++ b/weaviate/collections/data/async_.pyi @@ -1,6 +1,10 @@ import uuid as uuid_package from typing import Generic, List, Literal, Optional, Sequence, Union, overload +from weaviate.collections.batch.collection import _BatchCollectionWrapper +from weaviate.collections.batch.grpc_batch import _BatchGRPC +from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC +from weaviate.collections.batch.rest import _BatchREST from weaviate.collections.classes.batch import ( BatchObjectReturn, BatchReferenceReturn, @@ -23,6 +27,11 @@ from .executor import _DataCollectionExecutor class _DataCollectionAsync( Generic[Properties,], _DataCollectionExecutor[ConnectionAsync, Properties] ): + __batch_delete: _BatchDeleteGRPC + __batch_grpc: _BatchGRPC + __batch_rest: _BatchREST + __batch: _BatchCollectionWrapper[Properties] + async def insert( self, properties: Properties, diff --git a/weaviate/collections/data/executor.py b/weaviate/collections/data/executor.py index eb63a744d..8d6d12d40 100644 --- a/weaviate/collections/data/executor.py +++ b/weaviate/collections/data/executor.py @@ -19,6 +19,7 @@ from httpx import Response +from weaviate.collections.batch.collection import _BatchCollectionWrapper from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC from weaviate.collections.batch.rest import _BatchREST @@ -57,6 +58,11 @@ class _DataCollectionExecutor(Generic[ConnectionType, Properties]): + __batch_delete: _BatchDeleteGRPC + __batch_grpc: _BatchGRPC + __batch_rest: _BatchREST + __batch: _BatchCollectionWrapper[Properties] + def __init__( self, connection: ConnectionType, diff --git a/weaviate/collections/data/sync.pyi b/weaviate/collections/data/sync.pyi index 3fa145a4e..eda3da21a 100644 --- a/weaviate/collections/data/sync.pyi +++ b/weaviate/collections/data/sync.pyi @@ -1,6 +1,10 @@ import uuid as uuid_package from typing import Generic, List, Literal, Optional, Sequence, Union, overload +from weaviate.collections.batch.collection import _BatchCollectionWrapper +from weaviate.collections.batch.grpc_batch import _BatchGRPC +from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC +from weaviate.collections.batch.rest import _BatchREST from weaviate.collections.classes.batch import ( BatchObjectReturn, BatchReferenceReturn, @@ -21,6 +25,11 @@ from weaviate.types import UUID, VECTORS from .executor import _DataCollectionExecutor class _DataCollection(Generic[Properties,], _DataCollectionExecutor[ConnectionSync, Properties]): + __batch_delete: _BatchDeleteGRPC + __batch_grpc: _BatchGRPC + __batch_rest: _BatchREST + __batch: _BatchCollectionWrapper[Properties] + def insert( self, properties: Properties, diff --git a/weaviate/config.py b/weaviate/config.py index bc0525531..9d2006829 100644 --- a/weaviate/config.py +++ 
b/weaviate/config.py @@ -56,6 +56,9 @@ class Timeout(BaseModel): query: Union[int, float] = Field(default=30, ge=0) insert: Union[int, float] = Field(default=90, ge=0) init: Union[int, float] = Field(default=2, ge=0) + stream: Union[int, float, None] = Field( + default=None, ge=0, description="Timeout for streaming operations." + ) class Proxies(BaseModel): diff --git a/weaviate/connect/v4.py b/weaviate/connect/v4.py index 0220cc4bf..0848ba6be 100644 --- a/weaviate/connect/v4.py +++ b/weaviate/connect/v4.py @@ -20,13 +20,14 @@ overload, ) +import grpc from authlib.integrations.httpx_client import ( # type: ignore AsyncOAuth2Client, OAuth2Client, ) from grpc import Call, RpcError, StatusCode from grpc import Channel as SyncChannel # type: ignore -from grpc.aio import AioRpcError +from grpc.aio import AioRpcError, StreamStreamCall from grpc.aio import Channel as AsyncChannel # type: ignore # from grpclib.client import Channel @@ -1011,7 +1012,9 @@ def grpc_batch_stream( try: assert self.grpc_stub is not None for msg in self.grpc_stub.BatchStream( - request_iterator=requests, metadata=self.grpc_headers() + request_iterator=requests, + timeout=self.timeout_config.stream, + metadata=self.grpc_headers(), ): yield msg except RpcError as e: @@ -1088,8 +1091,8 @@ def grpc_aggregate( class ConnectionAsync(_ConnectionBase): """Connection class used to communicate to a weaviate instance.""" - async def connect(self) -> None: - if self._connected: + async def connect(self, force: bool = False) -> None: + if self._connected and not force: return None await executor.aresult(self._open_connections_rest(self._auth, "async")) @@ -1221,6 +1224,52 @@ async def grpc_batch_delete( raise InsufficientPermissionsError(e) raise WeaviateDeleteManyError(str(e)) + def grpc_batch_stream( + self, + ) -> StreamStreamCall[batch_pb2.BatchStreamRequest, batch_pb2.BatchStreamReply]: + assert isinstance(self._grpc_channel, grpc.aio.Channel) + return self._grpc_channel.stream_stream( + "/weaviate.v1.Weaviate/BatchStream", + request_serializer=batch_pb2.BatchStreamRequest.SerializeToString, + response_deserializer=batch_pb2.BatchStreamReply.FromString, + )( + request_iterator=None, + timeout=self.timeout_config.stream, + metadata=self.grpc_headers(), + ) + + async def grpc_batch_stream_write( + self, + stream: StreamStreamCall[batch_pb2.BatchStreamRequest, batch_pb2.BatchStreamReply], + request: batch_pb2.BatchStreamRequest, + ) -> None: + try: + await stream.write(request) + except AioRpcError as e: + error = cast(Call, e) + if error.code() == StatusCode.PERMISSION_DENIED: + raise InsufficientPermissionsError(error) + if error.code() == StatusCode.ABORTED: + raise _BatchStreamShutdownError() + raise WeaviateBatchStreamError(str(error.details())) + + async def grpc_batch_stream_read( + self, + stream: StreamStreamCall[batch_pb2.BatchStreamRequest, batch_pb2.BatchStreamReply], + ) -> Optional[batch_pb2.BatchStreamReply]: + try: + msg = await stream.read() + if not isinstance(msg, batch_pb2.BatchStreamReply): + return None + return msg + except AioRpcError as e: + error = cast(Call, e) + if error.code() == StatusCode.PERMISSION_DENIED: + raise InsufficientPermissionsError(error) + if error.code() == StatusCode.ABORTED: + raise _BatchStreamShutdownError() + raise WeaviateBatchStreamError(str(error.details())) + async def grpc_tenants_get( self, request: tenants_pb2.TenantsGetRequest ) -> tenants_pb2.TenantsGetReply:
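
Usage sketch for the new async server-side batching API added in this patch. This is illustrative rather than part of the diff: it assumes the async wrapper is exposed as `client.batch` on a connected `WeaviateAsyncClient` (mirroring the sync client), a pre-existing collection named `Article` (hypothetical), and a server new enough for server-side batching (the wrapper raises below 1.34.0); `concurrency` is currently hard-coded to 1 internally.

    import asyncio

    import weaviate


    async def main() -> None:
        client = weaviate.use_async_with_local()
        await client.connect()
        try:
            # __aenter__ starts the stream; __aexit__ shuts it down and waits for completion
            async with client.batch.experimental(concurrency=1) as batch:
                for i in range(1_000):
                    await batch.add_object(
                        collection="Article",  # hypothetical collection
                        properties={"title": f"article-{i}"},
                    )
            # results are copied back to the wrapper only after the context manager exits
            if client.batch.results.objs.has_errors:
                for err in client.batch.failed_objects:
                    print(err.message)
        finally:
            await client.close()


    asyncio.run(main())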
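A second hedged sketch, for the new `stream` field added to `weaviate.config.Timeout` in the `weaviate/config.py` hunk. The value is threaded through to the gRPC batch-stream calls as `timeout_config.stream`; `None` (the default) leaves the long-lived stream without a client-side deadline. The other `Timeout` fields shown are pre-existing.

    import weaviate
    from weaviate.config import AdditionalConfig, Timeout

    # only `stream` is new; the other fields keep their documented defaults
    client = weaviate.connect_to_local(
        additional_config=AdditionalConfig(
            timeout=Timeout(init=2, query=30, insert=90, stream=None),
        ),
    )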