diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 0d3e024f8..844d54424 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -27,7 +27,7 @@ env: WEAVIATE_133: 1.33.11 WEAVIATE_134: 1.34.8 WEAVIATE_135: 1.35.2 - WEAVIATE_136: 1.36.0-dev-c8f578d + WEAVIATE_136: 1.36.0-dev-0bbf31a jobs: lint-and-format: diff --git a/integration/test_batch_v4.py b/integration/test_batch_v4.py index ac8350b50..42a904986 100644 --- a/integration/test_batch_v4.py +++ b/integration/test_batch_v4.py @@ -433,7 +433,7 @@ def test_add_ten_thousand_data_objects( """Test adding ten thousand data objects.""" client, name = client_factory() if ( - request.node.callspec.id == "test_add_ten_thousand_data_objects_experimental" + request.node.callspec.id == "test_add_ten_thousand_data_objects_stream" and client._connection._weaviate_version.is_lower_than(1, 36, 0) ): pytest.skip("Server-side batching not supported in Weaviate < 1.36.0") @@ -641,7 +641,7 @@ def test_add_one_object_and_a_self_reference( """Test adding one object and a self reference.""" client, name = client_factory() if ( - request.node.callspec.id == "test_add_one_object_and_a_self_reference_experimental" + request.node.callspec.id == "test_add_one_object_and_a_self_reference_stream" and client._connection._weaviate_version.is_lower_than(1, 36, 0) ): pytest.skip("Server-side batching not supported in Weaviate < 1.36.0") diff --git a/integration/test_collection_batch.py b/integration/test_collection_batch.py index e89bfba9f..40683a26f 100644 --- a/integration/test_collection_batch.py +++ b/integration/test_collection_batch.py @@ -271,7 +271,7 @@ def test_non_existant_collection(collection_factory_get: CollectionFactoryGet) - @pytest.mark.asyncio -async def test_add_one_hundred_thousand_objects_async_collection( +async def test_batch_one_hundred_thousand_objects_async_collection( batch_collection_async: BatchCollectionAsync, ) -> None: """Test adding one hundred thousand data objects.""" @@ -295,3 +295,46 @@ async def test_add_one_hundred_thousand_objects_async_collection( assert await col.length() == nr_objects assert col.batch.results.objs.has_errors is False assert len(col.batch.failed_objects) == 0, [obj.message for obj in col.batch.failed_objects] + + +@pytest.mark.asyncio +async def test_ingest_one_hundred_thousand_data_objects_async( + batch_collection_async: BatchCollectionAsync, +) -> None: + col = await batch_collection_async() + if col._connection._weaviate_version.is_lower_than(1, 36, 0): + pytest.skip("Server-side batching not supported in Weaviate < 1.36.0") + nr_objects = 100000 + import time + + start = time.time() + results = await col.data.ingest({"name": "test" + str(i)} for i in range(nr_objects)) + end = time.time() + print(f"Time taken to add {nr_objects} objects: {end - start} seconds") + assert len(results.errors) == 0 + assert len(results.all_responses) == nr_objects + assert len(results.uuids) == nr_objects + assert await col.length() == nr_objects + assert results.has_errors is False + assert len(results.errors) == 0, [obj.message for obj in results.errors.values()] + + +def test_ingest_one_hundred_thousand_data_objects( + batch_collection: BatchCollection, +) -> None: + col = batch_collection() + if col._connection._weaviate_version.is_lower_than(1, 36, 0): + pytest.skip("Server-side batching not supported in Weaviate < 1.36.0") + nr_objects = 100000 + import time + + start = time.time() + results = col.data.ingest({"name": "test" + str(i)} for i in range(nr_objects)) + end = time.time() + print(f"Time taken to add {nr_objects} objects: {end - start} seconds") + assert len(results.errors) == 0 + assert len(results.all_responses) == nr_objects + assert len(results.uuids) == nr_objects + assert len(col) == nr_objects + assert results.has_errors is False + assert len(results.errors) == 0, [obj.message for obj in results.errors.values()] diff --git a/integration/test_rbac.py b/integration/test_rbac.py index 9a8b3c6d7..93c930672 100644 --- a/integration/test_rbac.py +++ b/integration/test_rbac.py @@ -742,8 +742,8 @@ def test_server_side_batching_with_auth() -> None: with connect_to_local( port=RBAC_PORTS[0], grpc_port=RBAC_PORTS[1], auth_credentials=RBAC_AUTH_CREDS ) as client: - if client._connection._weaviate_version.is_lower_than(1, 34, 0): - pytest.skip("Server-side batching not supported in Weaviate < 1.34.0") + if client._connection._weaviate_version.is_lower_than(1, 36, 0): + pytest.skip("Server-side batching not supported in Weaviate < 1.36.0") collection = client.collections.create(collection_name) with client.batch.stream() as batch: batch.add_object(collection_name) diff --git a/weaviate/collections/batch/base.py b/weaviate/collections/batch/base.py index 4a87a0acf..14a0c0768 100644 --- a/weaviate/collections/batch/base.py +++ b/weaviate/collections/batch/base.py @@ -359,6 +359,9 @@ def number_errors(self) -> int: def _start(self): pass + def _wait(self): + pass + def _shutdown(self) -> None: """Shutdown the current batch and wait for all requests to be finished.""" self.flush() diff --git a/weaviate/collections/batch/batch_wrapper.py b/weaviate/collections/batch/batch_wrapper.py index 3c1acc827..a3a3598d6 100644 --- a/weaviate/collections/batch/batch_wrapper.py +++ b/weaviate/collections/batch/batch_wrapper.py @@ -10,7 +10,6 @@ _ClusterBatch, _ClusterBatchAsync, _DynamicBatching, - _ServerSideBatching, ) from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.batch import ( @@ -140,8 +139,6 @@ def __init__( self._connection = connection self._consistency_level = consistency_level self._current_batch: Optional[_BatchBaseAsync] = None - # config options - self._batch_mode: _BatchMode = _ServerSideBatching(1) self._batch_data = _BatchDataWrapper() self._cluster = _ClusterBatchAsync(connection) @@ -371,7 +368,7 @@ async def add_reference( """ ... - async def flush(self) -> None: + def flush(self) -> None: """Flush the current batch. This will send all the objects and references in the current batch to Weaviate. @@ -505,19 +502,20 @@ def number_errors(self) -> int: Q = TypeVar("Q", bound=Union[BatchClientProtocolAsync, BatchCollectionProtocolAsync[Properties]]) -class _ContextManagerWrapper(Generic[T, P]): +class _ContextManagerSync(Generic[T, P]): def __init__(self, current_batch: T): self.__current_batch: T = current_batch def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.__current_batch._shutdown() + self.__current_batch._wait() def __enter__(self) -> P: self.__current_batch._start() return self.__current_batch # pyright: ignore[reportReturnType] -class _ContextManagerWrapperAsync(Generic[Q]): +class _ContextManagerAsync(Generic[Q]): def __init__(self, current_batch: _BatchBaseAsync): self.__current_batch = current_batch diff --git a/weaviate/collections/batch/client.py b/weaviate/collections/batch/client.py index 789236184..d28834c66 100644 --- a/weaviate/collections/batch/client.py +++ b/weaviate/collections/batch/client.py @@ -19,8 +19,8 @@ _BatchMode, _BatchWrapper, _BatchWrapperAsync, - _ContextManagerWrapper, - _ContextManagerWrapperAsync, + _ContextManagerAsync, + _ContextManagerSync, ) from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers @@ -146,10 +146,10 @@ async def add_reference( BatchClient = _BatchClient BatchClientSync = _BatchClientSync BatchClientAsync = _BatchClientAsync -ClientBatchingContextManager = _ContextManagerWrapper[ +ClientBatchingContextManager = _ContextManagerSync[ Union[BatchClient, BatchClientSync], BatchClientProtocol ] -AsyncClientBatchingContextManager = _ContextManagerWrapperAsync[BatchClientProtocolAsync] +ClientBatchingContextManagerAsync = _ContextManagerAsync[BatchClientProtocolAsync] class _BatchClientWrapper(_BatchWrapper): @@ -196,7 +196,7 @@ def __create_batch_and_reset( self._batch_data = _BatchDataWrapper() # clear old data - return _ContextManagerWrapper( + return _ContextManagerSync( batch_client( connection=self._connection, consistency_level=self._consistency_level, @@ -310,7 +310,7 @@ def __init__( def __create_batch_and_reset(self): self._batch_data = _BatchDataWrapper() # clear old data - return _ContextManagerWrapperAsync( + return _ContextManagerAsync( BatchClientAsync( connection=self._connection, consistency_level=self._consistency_level, @@ -328,7 +328,7 @@ def experimental( *, concurrency: Optional[int] = None, consistency_level: Optional[ConsistencyLevel] = None, - ) -> AsyncClientBatchingContextManager: + ) -> ClientBatchingContextManagerAsync: return self.stream(concurrency=concurrency, consistency_level=consistency_level) def stream( @@ -336,7 +336,7 @@ def stream( *, concurrency: Optional[int] = None, consistency_level: Optional[ConsistencyLevel] = None, - ) -> AsyncClientBatchingContextManager: + ) -> ClientBatchingContextManagerAsync: """Configure the batching context manager to use batch streaming. When you exit the context manager, the final batch will be sent automatically. @@ -345,9 +345,9 @@ def stream( concurrency: The number of concurrent streams to use when sending batches. If not provided, the default will be one. consistency_level: The consistency level to be used when inserting data. If not provided, the default value is `None`. """ - if self._connection._weaviate_version.is_lower_than(1, 34, 0): + if self._connection._weaviate_version.is_lower_than(1, 36, 0): raise WeaviateUnsupportedFeatureError( - "Server-side batching", str(self._connection._weaviate_version), "1.34.0" + "Server-side batching", str(self._connection._weaviate_version), "1.36.0" ) self._batch_mode = _ServerSideBatching( # concurrency=concurrency diff --git a/weaviate/collections/batch/collection.py b/weaviate/collections/batch/collection.py index c7e701a34..7889db335 100644 --- a/weaviate/collections/batch/collection.py +++ b/weaviate/collections/batch/collection.py @@ -19,8 +19,8 @@ BatchCollectionProtocolAsync, _BatchWrapper, _BatchWrapperAsync, - _ContextManagerWrapper, - _ContextManagerWrapperAsync, + _ContextManagerAsync, + _ContextManagerSync, ) from weaviate.collections.batch.sync import _BatchBaseSync from weaviate.collections.classes.config import ConsistencyLevel, Vectorizers @@ -88,14 +88,14 @@ def add_reference( class _BatchCollectionSync(Generic[Properties], _BatchBaseSync): def __init__( self, - executor: ThreadPoolExecutor, connection: ConnectionSync, consistency_level: Optional[ConsistencyLevel], results: _BatchDataWrapper, - batch_mode: _BatchMode, name: str, tenant: Optional[str], - vectorizer_batching: bool, + executor: Optional[ThreadPoolExecutor] = None, + batch_mode: Optional[_BatchMode] = None, + vectorizer_batching: bool = False, ) -> None: super().__init__( connection=connection, @@ -184,11 +184,11 @@ async def add_reference( BatchCollection = _BatchCollection BatchCollectionSync = _BatchCollectionSync BatchCollectionAsync = _BatchCollectionAsync -CollectionBatchingContextManager = _ContextManagerWrapper[ +CollectionBatchingContextManager = _ContextManagerSync[ Union[BatchCollection[Properties], BatchCollectionSync[Properties]], BatchCollectionProtocol[Properties], ] -CollectionBatchingContextManagerAsync = _ContextManagerWrapperAsync[ +CollectionBatchingContextManagerAsync = _ContextManagerAsync[ BatchCollectionProtocolAsync[Properties] ] @@ -239,7 +239,7 @@ def __create_batch_and_reset( self._vectorizer_batching = False self._batch_data = _BatchDataWrapper() # clear old data - return _ContextManagerWrapper( + return _ContextManagerSync( batch_client( connection=self._connection, consistency_level=self._consistency_level, @@ -311,9 +311,9 @@ def stream( concurrency: The number of concurrent requests when sending batches. This controls the number of concurrent requests made to Weaviate. If not provided, the default value is 1. """ - if self._connection._weaviate_version.is_lower_than(1, 34, 0): + if self._connection._weaviate_version.is_lower_than(1, 36, 0): raise WeaviateUnsupportedFeatureError( - "Server-side batching", str(self._connection._weaviate_version), "1.34.0" + "Server-side batching", str(self._connection._weaviate_version), "1.36.0" ) self._batch_mode = _ServerSideBatching( # concurrency=concurrency @@ -338,7 +338,7 @@ def __init__( def __create_batch_and_reset(self): self._batch_data = _BatchDataWrapper() # clear old data - return _ContextManagerWrapperAsync( + return _ContextManagerAsync( BatchCollectionAsync( connection=self._connection, consistency_level=self._consistency_level, @@ -371,9 +371,9 @@ def stream( concurrency: The number of concurrent requests when sending batches. This controls the number of concurrent requests made to Weaviate. If not provided, the default value is 1. """ - if self._connection._weaviate_version.is_lower_than(1, 34, 0): + if self._connection._weaviate_version.is_lower_than(1, 36, 0): raise WeaviateUnsupportedFeatureError( - "Server-side batching", str(self._connection._weaviate_version), "1.34.0" + "Server-side batching", str(self._connection._weaviate_version), "1.36.0" ) self._batch_mode = _ServerSideBatching( # concurrency=concurrency diff --git a/weaviate/collections/batch/sync.py b/weaviate/collections/batch/sync.py index 0c59c92c4..89aff0085 100644 --- a/weaviate/collections/batch/sync.py +++ b/weaviate/collections/batch/sync.py @@ -15,7 +15,6 @@ _BatchMode, _BgThreads, _ClusterBatch, - _ServerSideBatching, ) from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.classes.batch import ( @@ -54,9 +53,9 @@ def __init__( connection: ConnectionSync, consistency_level: Optional[ConsistencyLevel], results: _BatchDataWrapper, - batch_mode: _BatchMode, - executor: ThreadPoolExecutor, - vectorizer_batching: bool, + batch_mode: Optional[_BatchMode] = None, + executor: Optional[ThreadPoolExecutor] = None, + vectorizer_batching: bool = False, objects: Optional[ObjectsBatchRequest[BatchObject]] = None, references: Optional[ReferencesBatchRequest[BatchReference]] = None, ) -> None: @@ -108,8 +107,6 @@ def __init__( self.__reqs: Queue[Optional[batch_pb2.BatchStreamRequest]] = Queue(maxsize=1) self.__stop = False - self.__batch_mode = batch_mode - @property def number_errors(self) -> int: """Return the number of errors in the batch.""" @@ -123,12 +120,7 @@ def __all_threads_alive(self) -> bool: ) def _start(self) -> None: - assert isinstance(self.__batch_mode, _ServerSideBatching), ( - "Only server-side batching is supported in this mode" - ) - self.__bg_threads = [ - self.__start_bg_threads() for _ in range(self.__batch_mode.concurrency) - ] + self.__bg_threads = [self.__start_bg_threads() for _ in range(1)] logger.info( f"Provisioned {len(self.__bg_threads)} stream(s) to the server for batch processing" ) diff --git a/weaviate/collections/data/async_.pyi b/weaviate/collections/data/async_.pyi index 15108447a..8dd092bf3 100644 --- a/weaviate/collections/data/async_.pyi +++ b/weaviate/collections/data/async_.pyi @@ -1,7 +1,6 @@ import uuid as uuid_package -from typing import Generic, List, Literal, Optional, Sequence, Union, overload +from typing import Generic, Iterable, List, Literal, Optional, Sequence, Union, overload -from weaviate.collections.batch.collection import _BatchCollectionWrapper from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC from weaviate.collections.batch.rest import _BatchREST @@ -30,7 +29,6 @@ class _DataCollectionAsync( __batch_delete: _BatchDeleteGRPC __batch_grpc: _BatchGRPC __batch_rest: _BatchREST - __batch: _BatchCollectionWrapper[Properties] async def insert( self, @@ -81,3 +79,6 @@ class _DataCollectionAsync( async def delete_many( self, where: _Filters, *, verbose: bool = False, dry_run: bool = False ) -> Union[DeleteManyReturn[List[DeleteManyObject]], DeleteManyReturn[None]]: ... + async def ingest( + self, objs: Iterable[Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]] + ) -> BatchObjectReturn: ... diff --git a/weaviate/collections/data/executor.py b/weaviate/collections/data/executor.py index 8d6d12d40..73a07b1f9 100644 --- a/weaviate/collections/data/executor.py +++ b/weaviate/collections/data/executor.py @@ -5,6 +5,7 @@ Any, Dict, Generic, + Iterable, List, Literal, Mapping, @@ -19,7 +20,13 @@ from httpx import Response -from weaviate.collections.batch.collection import _BatchCollectionWrapper +from weaviate.collections.batch.base import _BatchDataWrapper +from weaviate.collections.batch.collection import ( + BatchCollectionAsync, + BatchCollectionSync, + CollectionBatchingContextManager, + CollectionBatchingContextManagerAsync, +) from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC from weaviate.collections.batch.rest import _BatchREST @@ -61,7 +68,6 @@ class _DataCollectionExecutor(Generic[ConnectionType, Properties]): __batch_delete: _BatchDeleteGRPC __batch_grpc: _BatchGRPC __batch_rest: _BatchREST - __batch: _BatchCollectionWrapper[Properties] def __init__( self, @@ -704,3 +710,78 @@ def __parse_vector(self, obj: Dict[str, Any], vector: VECTORS) -> Dict[str, Any] else: obj["vector"] = _get_vector_v4(vector) return obj + + def ingest( + self, objs: Iterable[Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]] + ) -> executor.Result[BatchObjectReturn]: + """Ingest multiple objects into the collection in batches. The batching is handled automatically for you by Weaviate. + + This is different from `insert_many` which sends all objects in a single batch request. Use this method when you want to insert a large number of objects without worrying about batch sizes + and whether they will fit into the maximum allowed batch size of your Weaviate instance. In addition, use this instead of `client.batch.dynamic()` or `collection.batch.dynamic()` for a more + performant dynamic batching algorithm that utilizes server-side batching. + + Args: + objs: An iterable of objects to insert. This can be either a sequence of `Properties` or `DataObject[Properties, ReferenceInputs]` + If you didn't set `data_model` then `Properties` will be `Data[str, Any]` in which case you can insert simple dictionaries here. + """ + if isinstance(self._connection, ConnectionAsync): + con = self._connection + + async def execute() -> BatchObjectReturn: + results = _BatchDataWrapper() + ctx = CollectionBatchingContextManagerAsync( + BatchCollectionAsync( + connection=con, + results=results, + consistency_level=self._consistency_level, + name=self.name, + tenant=self._tenant, + ) + ) + async with ctx as batch: + for obj in objs: + if isinstance(obj, DataObject): + await batch.add_object( + properties=cast(dict, obj.properties), + references=obj.references, + uuid=obj.uuid, + vector=obj.vector, + ) + else: + await batch.add_object( + properties=cast(dict, obj), + references=None, + uuid=None, + vector=None, + ) + return results.results.objs + + return execute() + + results = _BatchDataWrapper() + ctx = CollectionBatchingContextManager( + BatchCollectionSync( + connection=self._connection, + results=results, + consistency_level=self._consistency_level, + name=self.name, + tenant=self._tenant, + ) + ) + with ctx as batch: + for obj in objs: + if isinstance(obj, DataObject): + batch.add_object( + properties=cast(dict, obj.properties), + references=obj.references, + uuid=obj.uuid, + vector=obj.vector, + ) + else: + batch.add_object( + properties=cast(dict, obj), + references=None, + uuid=None, + vector=None, + ) + return results.results.objs diff --git a/weaviate/collections/data/sync.pyi b/weaviate/collections/data/sync.pyi index eda3da21a..ab1eb3f39 100644 --- a/weaviate/collections/data/sync.pyi +++ b/weaviate/collections/data/sync.pyi @@ -1,7 +1,6 @@ import uuid as uuid_package -from typing import Generic, List, Literal, Optional, Sequence, Union, overload +from typing import Generic, Iterable, List, Literal, Optional, Sequence, Union, overload -from weaviate.collections.batch.collection import _BatchCollectionWrapper from weaviate.collections.batch.grpc_batch import _BatchGRPC from weaviate.collections.batch.grpc_batch_delete import _BatchDeleteGRPC from weaviate.collections.batch.rest import _BatchREST @@ -28,7 +27,6 @@ class _DataCollection(Generic[Properties,], _DataCollectionExecutor[ConnectionSy __batch_delete: _BatchDeleteGRPC __batch_grpc: _BatchGRPC __batch_rest: _BatchREST - __batch: _BatchCollectionWrapper[Properties] def insert( self, @@ -79,3 +77,6 @@ class _DataCollection(Generic[Properties,], _DataCollectionExecutor[ConnectionSy def delete_many( self, where: _Filters, *, verbose: bool = False, dry_run: bool = False ) -> Union[DeleteManyReturn[List[DeleteManyObject]], DeleteManyReturn[None]]: ... + def ingest( + self, objs: Iterable[Union[Properties, DataObject[Properties, Optional[ReferenceInputs]]]] + ) -> BatchObjectReturn: ...