From 03522641b6b786392016941f8065f079601073fb Mon Sep 17 00:00:00 2001 From: Grain Team Date: Fri, 24 Oct 2025 09:56:37 -0700 Subject: [PATCH] Internal PiperOrigin-RevId: 823568742 --- grain/_src/python/BUILD | 18 + .../dataset/transformations/prefetch.py | 8 + .../dataset/transformations/prefetch_test.py | 71 ++++ grain/_src/python/grain_pool.py | 82 ++++- grain/_src/python/grain_pool_test.py | 56 +++- grain/_src/python/variable_size_queue.py | 168 ++++++++++ grain/_src/python/variable_size_queue_test.py | 310 ++++++++++++++++++ 7 files changed, 699 insertions(+), 14 deletions(-) create mode 100644 grain/_src/python/variable_size_queue.py create mode 100644 grain/_src/python/variable_size_queue_test.py diff --git a/grain/_src/python/BUILD b/grain/_src/python/BUILD index 841ecc25c..ca1411ce4 100644 --- a/grain/_src/python/BUILD +++ b/grain/_src/python/BUILD @@ -224,6 +224,7 @@ py_library( ":options", ":record", ":shared_memory_array", + ":variable_size_queue", "//grain/_src/core:config", "//grain/_src/core:monitoring", "//grain/_src/core:parallel", @@ -245,6 +246,7 @@ py_test( ":grain_pool", ":options", ":record", + ":variable_size_queue", "//grain/_src/core:config", "//grain/_src/core:monitoring", "@abseil-py//absl/flags", @@ -370,3 +372,19 @@ py_library( "@pypi//etils:pkg", ], ) + +py_library( + name = "variable_size_queue", + srcs = ["variable_size_queue.py"], + srcs_version = "PY3", +) + +py_test( + name = "variable_size_queue_test", + srcs = ["variable_size_queue_test.py"], + srcs_version = "PY3", + deps = [ + ":variable_size_queue", + "@abseil-py//absl/testing:absltest", + ], +) diff --git a/grain/_src/python/dataset/transformations/prefetch.py b/grain/_src/python/dataset/transformations/prefetch.py index 1f58925e4..e98fbe640 100644 --- a/grain/_src/python/dataset/transformations/prefetch.py +++ b/grain/_src/python/dataset/transformations/prefetch.py @@ -760,6 +760,14 @@ def __str__(self) -> str: f"multiprocessing_options={self._multiprocessing_options})" ) + def set_per_worker_buffer_size(self, per_worker_buffer_size: int): + if self._raw_iterator is None: + raise ValueError( + "Cannot change per worker buffer size before the iterator has been" + " initialized." + ) + self._raw_iterator.set_per_worker_buffer_size(per_worker_buffer_size) + class ThreadPrefetchIterDataset(dataset.IterDataset[T]): """Iterable dataset that uses a synchronized queue for prefetching. 
diff --git a/grain/_src/python/dataset/transformations/prefetch_test.py b/grain/_src/python/dataset/transformations/prefetch_test.py index 57c683a75..9aacff8a0 100644 --- a/grain/_src/python/dataset/transformations/prefetch_test.py +++ b/grain/_src/python/dataset/transformations/prefetch_test.py @@ -917,6 +917,77 @@ def map_fn(x): ], ) + def test_set_per_worker_buffer_size_increase(self): + ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset() + mp_options = options.MultiprocessingOptions( + num_workers=1, per_worker_buffer_size=1 + ) + ds = prefetch.MultiprocessPrefetchIterDataset( + ds, + mp_options, + ) + it = cast(prefetch._MultiprocessPrefetchDatasetIterator, ds.__iter__()) + self.assertEqual(next(it), 1) + time.sleep(1) + self.assertEqual( + it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 1 # pytype: disable=attribute-error + ) + it.set_per_worker_buffer_size(2) + self.assertEqual( + it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 2 # pytype: disable=attribute-error + ) + self.assertEqual(next(it), 2) + self.assertEqual(list(it), list(range(3, 11))) + + def test_set_per_worker_buffer_size_decrease(self): + ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset() + mp_options = options.MultiprocessingOptions( + num_workers=1, per_worker_buffer_size=2 + ) + ds = prefetch.MultiprocessPrefetchIterDataset( + ds, + mp_options, + ) + it = cast(prefetch._MultiprocessPrefetchDatasetIterator, ds.__iter__()) + self.assertEqual(next(it), 1) + time.sleep(1) + self.assertEqual( + it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 2 # pytype: disable=attribute-error + ) + it.set_per_worker_buffer_size(1) + self.assertEqual( + it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 1 # pytype: disable=attribute-error + ) + self.assertEqual(next(it), 2) + self.assertEqual(list(it), list(range(3, 11))) + + def test_set_per_worker_buffer_size_to_trigger_error(self): + def f(x): + if x >= 5: + raise ValueError(f'x={x} is too large') + return x + + ds = ( + dataset.MapDataset.range(10) + .map(f) + .to_iter_dataset( + read_options=options.ReadOptions(prefetch_buffer_size=0) + ) + ) + mp_options = options.MultiprocessingOptions( + num_workers=1, per_worker_buffer_size=1 + ) + it = prefetch.MultiprocessPrefetchIterDataset(ds, mp_options).__iter__() + it = cast(prefetch._MultiprocessPrefetchDatasetIterator, it) + self.assertEqual(next(it), 0) + it.set_per_worker_buffer_size(10) + next(it) + time.sleep(3) + q = it._raw_iterator._reader_queue # pytype: disable=attribute-error + # Prefetching will end once an error is put into the reader queue. The + # elements 2, 3, 4 will be in the queue along with the error for 5. 
+ self.assertEqual(q.qsize(), 4) + class ThreadPrefetchIterDatasetTest(parameterized.TestCase): diff --git a/grain/_src/python/grain_pool.py b/grain/_src/python/grain_pool.py index 0bcf381e3..c6bc26882 100644 --- a/grain/_src/python/grain_pool.py +++ b/grain/_src/python/grain_pool.py @@ -69,6 +69,7 @@ from grain._src.python import multiprocessing_common from grain._src.python import record from grain._src.python import shared_memory_array +from grain._src.python import variable_size_queue from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member @@ -225,7 +226,7 @@ def _worker_loop( *, args_queue: queues.Queue, errors_queue: queues.Queue, - output_queue: queues.Queue, + output_queue: variable_size_queue.VariableSizeMultiprocessingQueue, termination_event: synchronize.Event, start_profiling_event: synchronize.Event, stop_profiling_event: synchronize.Event, @@ -342,6 +343,9 @@ def __init__( options: MultiprocessingOptions, worker_init_fn: Callable[[int, int], None] | None = None, stats_in_queues: tuple[queues.Queue, ...] | None = None, + worker_output_queues: list[ + variable_size_queue.VariableSizeMultiprocessingQueue + ], ): """Initialise a Grain Pool. @@ -362,11 +366,13 @@ def __init__( the total worker count. stats_in_queues: Queue to propagate execution summary from child processes to the parent. + worker_output_queues: list of queues for each worker to output elements + to. """ self.num_processes = options.num_workers logging.info("Grain pool will use %i processes.", self.num_processes) self.worker_args_queues = [] - self.worker_output_queues = [] + self.worker_output_queues = worker_output_queues self.processes = [] # Reader termination should always result in worker termination. However, # worker termination should not shut down the reader: workers are terminated @@ -396,11 +402,10 @@ def __init__( for worker_index in range(self.num_processes): worker_args_queue = ctx.Queue(1) - worker_output_queue = ctx.Queue(options.per_worker_buffer_size) process_kwargs = dict( args_queue=worker_args_queue, errors_queue=self.worker_error_queue, - output_queue=worker_output_queue, + output_queue=self.worker_output_queues[worker_index], stats_out_queue=( self.stats_in_queues[worker_index] if self.stats_in_queues @@ -434,7 +439,6 @@ def __init__( target=_worker_loop, kwargs=process_kwargs, daemon=True ) self.worker_args_queues.append(worker_args_queue) - self.worker_output_queues.append(worker_output_queue) self.processes.append(process) logging.info("Grain pool will start child processes.") @@ -589,6 +593,27 @@ class _GrainPoolProcessingComplete: ] +class _ThreadPoolContainer: + """Container for ThreadPool to allow replacing it.""" + + def __init__(self, processes: int): + self.pool = pool.ThreadPool(processes) + + def apply_async(self, *args, **kwargs): + return self.pool.apply_async(*args, **kwargs) + + def close(self): + self.pool.close() + + def join(self): + self.pool.join() + + def replace_pool(self, num_threads: int): + old_pool = self.pool + self.pool = pool.ThreadPool(num_threads) + old_pool.close() + + def _open_shared_memory_for_leaf(element: Any) -> Any: if isinstance(element, shared_memory_array.SharedMemoryArrayMetadata): element = shared_memory_array.SharedMemoryArray.from_metadata(element) @@ -610,6 +635,9 @@ def _process_elements_in_grain_pool( get_element_producer_fn: GetElementProducerFn, multiprocessing_options: MultiprocessingOptions, reader_queue: queue.Queue[_QueueElement], + worker_output_queues: list[ + 
variable_size_queue.VariableSizeMultiprocessingQueue + ], thread_pool: pool.ThreadPool, termination_event: threading.Event, start_profiling_event: synchronize.Event | None, @@ -636,6 +664,7 @@ def read_thread_should_stop(): options=multiprocessing_options, worker_init_fn=worker_init_fn, stats_in_queues=stats_in_queues, + worker_output_queues=worker_output_queues, ) as g_pool: for element in g_pool: if read_thread_should_stop(): @@ -714,6 +743,7 @@ def __init__( self._last_worker_index = worker_index_to_start_reading - 1 self._worker_init_fn = worker_init_fn self._reader_queue = None + self._worker_output_queues = None self._reader_thread_pool = None self._termination_event = None self._reader_thread = None @@ -736,15 +766,26 @@ def start_prefetch(self) -> None: self._multiprocessing_options.num_workers * self._multiprocessing_options.per_worker_buffer_size ) - self._reader_queue = queue.Queue(maxsize=max_buffered_elements) - self._reader_thread_pool = pool.ThreadPool(max_buffered_elements) + self._reader_queue = variable_size_queue.VariableSizeQueue( + max_buffered_elements + ) + self._reader_thread_pool = _ThreadPoolContainer(max_buffered_elements) self._termination_event = threading.Event() + ctx = mp.get_context("spawn") + self._worker_output_queues = [] + for _ in range(self._multiprocessing_options.num_workers): + self._worker_output_queues.append( + variable_size_queue.VariableSizeMultiprocessingQueue( + self._multiprocessing_options.per_worker_buffer_size, ctx + ) + ) self._reader_thread = threading.Thread( target=_process_elements_in_grain_pool, kwargs=dict( get_element_producer_fn=self._get_element_producer_fn, multiprocessing_options=self._multiprocessing_options, reader_queue=self._reader_queue, + worker_output_queues=self._worker_output_queues, thread_pool=self._reader_thread_pool, termination_event=self._termination_event, start_profiling_event=self._start_profiling_event, @@ -775,6 +816,7 @@ def stop_prefetch(self) -> None: self._reader_thread_pool = None self._reader_thread = None self._reader_queue = None + self._worker_output_queues = None def __enter__(self): self.start_prefetch() @@ -809,7 +851,7 @@ def __next__(self): "MultiProcessIterator is in an invalid state. Note that" " MultiProcessIterator should be used with a 'with' statement." ) - element = multiprocessing_common.get_element_from_queue( + element = multiprocessing_common.get_element_from_queue( # pytype: disable=wrong-arg-types self._reader_queue, self._termination_event.is_set # pytype: disable=attribute-error ) if isinstance(element, Exception): @@ -826,9 +868,31 @@ def __next__(self): ) result = multiprocessing_common.get_async_result( - element.async_result, self._termination_event.is_set + element.async_result, self._termination_event.is_set # pytype: disable=attribute-error ) if isinstance(result, multiprocessing_common._SystemTerminated): # pylint: disable=protected-access raise StopIteration self._last_worker_index = element.worker_index return result + + def set_per_worker_buffer_size(self, per_worker_buffer_size: int): + """Sets the per worker buffer size.""" + if self._worker_output_queues is None or self._reader_queue is None: + raise ValueError( + "Cannot change per worker buffer size before the iterator has been" + " initialized." 
+ ) + for q in self._worker_output_queues: + q.set_max_size(per_worker_buffer_size) + self._reader_queue.set_max_size( + per_worker_buffer_size * self._multiprocessing_options.num_workers + ) + self._multiprocessing_options = dataclasses.replace( + self._multiprocessing_options, + per_worker_buffer_size=per_worker_buffer_size, + ) + new_thread_count = ( + self._multiprocessing_options.num_workers + * self._multiprocessing_options.per_worker_buffer_size + ) + self._reader_thread_pool.replace_pool(new_thread_count) # pytype: disable=attribute-error diff --git a/grain/_src/python/grain_pool_test.py b/grain/_src/python/grain_pool_test.py index 5aa87795b..5efc5df01 100644 --- a/grain/_src/python/grain_pool_test.py +++ b/grain/_src/python/grain_pool_test.py @@ -29,6 +29,7 @@ from grain._src.python import data_sources from grain._src.python import grain_pool as gp from grain._src.python import record +from grain._src.python import variable_size_queue from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member @@ -50,11 +51,19 @@ def __call__(self, *, worker_index: int, worker_count: int, **kwargs): get_element_producer_fn = GetElementProducerFn() # unparse the flags explicitly flags.FLAGS.unparse_flags() - + ctx = mp.get_context("spawn") + options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + worker_output_queues = [ + variable_size_queue.VariableSizeMultiprocessingQueue( + options.per_worker_buffer_size, ctx + ) + for _ in range(options.num_workers) + ] _ = gp.GrainPool( - ctx=mp.get_context("spawn"), + ctx=ctx, get_element_producer_fn=get_element_producer_fn, - options=MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1), + options=options, + worker_output_queues=worker_output_queues, ) def test_pool_equal_split_in_memory_data_source(self): @@ -71,10 +80,19 @@ def __call__(self, *, worker_index: int, worker_count: int, **kwargs): get_element_producer_fn = GetElementProducerFn() output_elements = [] + ctx = mp.get_context("spawn") + options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + worker_output_queues = [ + variable_size_queue.VariableSizeMultiprocessingQueue( + options.per_worker_buffer_size, ctx + ) + for _ in range(options.num_workers) + ] with gp.GrainPool( - ctx=mp.get_context("spawn"), + ctx=ctx, get_element_producer_fn=get_element_producer_fn, - options=MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1), + options=options, + worker_output_queues=worker_output_queues, ) as grain_pool: for element in grain_pool: output_elements.append(element) @@ -100,11 +118,18 @@ def __call__(self, *, worker_index: int, worker_count: int, **kwargs): get_element_producer_fn = GetElementProducerFn() options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + worker_output_queues = [ + variable_size_queue.VariableSizeMultiprocessingQueue( + options.per_worker_buffer_size, ctx + ) + for _ in range(options.num_workers) + ] output_elements = [] with gp.GrainPool( ctx=ctx, get_element_producer_fn=get_element_producer_fn, options=options, + worker_output_queues=worker_output_queues, ) as grain_pool: for element in grain_pool: output_elements.append(element) @@ -133,11 +158,18 @@ def __call__(self, *, worker_index: int, worker_count: int, **kwargs): get_element_producer_fn = GetElementProducerFn() options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + worker_output_queues = [ + variable_size_queue.VariableSizeMultiprocessingQueue( + options.per_worker_buffer_size, ctx 
+ ) + for _ in range(options.num_workers) + ] output_elements = [] with gp.GrainPool( ctx=ctx, get_element_producer_fn=get_element_producer_fn, options=options, + worker_output_queues=worker_output_queues, ) as grain_pool: for element in grain_pool: output_elements.append(element) @@ -166,10 +198,17 @@ def __call__(self, *, worker_index: int, worker_count: int, **kwargs): get_element_producer_fn = GetElementProducerFn() options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + worker_output_queues = [ + variable_size_queue.VariableSizeMultiprocessingQueue( + options.per_worker_buffer_size, ctx + ) + for _ in range(options.num_workers) + ] with gp.GrainPool( ctx=ctx, get_element_producer_fn=get_element_producer_fn, options=options, + worker_output_queues=worker_output_queues, ) as grain_pool: child_pid = grain_pool.processes[0].pid os.kill(child_pid, signal.SIGKILL) @@ -192,6 +231,12 @@ def __call__(self, *, worker_index: int, worker_count: int, **kwargs): get_element_producer_fn = GetElementProducerFn() options = MultiprocessingOptions(num_workers=4, per_worker_buffer_size=1) + worker_output_queues = [ + variable_size_queue.VariableSizeMultiprocessingQueue( + options.per_worker_buffer_size, ctx + ) + for _ in range(options.num_workers) + ] # Users should generally use the with statement, here we test if GrainPool # was created without the "with statement", that object deletion would @@ -200,6 +245,7 @@ def __call__(self, *, worker_index: int, worker_count: int, **kwargs): ctx=ctx, get_element_producer_fn=get_element_producer_fn, options=options, + worker_output_queues=worker_output_queues, ) child_processes = grain_pool.processes diff --git a/grain/_src/python/variable_size_queue.py b/grain/_src/python/variable_size_queue.py new file mode 100644 index 000000000..9f9a07201 --- /dev/null +++ b/grain/_src/python/variable_size_queue.py @@ -0,0 +1,168 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""This module provides variable size queue implementations.""" + +from multiprocessing import context +from multiprocessing import queues +from multiprocessing import reduction +import queue +import threading +import time +from typing import Any, cast + + +class VariableSizeMultiprocessingQueue(queues.Queue): + """A multiprocessing queue whose max size can be dynamically changed.""" + + def __init__( + self, + max_size: int, + ctx: context.BaseContext, + ): + super().__init__(maxsize=max_size, ctx=ctx) + self._max_size_val = ctx.Value("i", max_size, lock=False) + self._sem = ctx.Semaphore(max_size) + self._resize_lock = ctx.Lock() + self._pending_shrink = ctx.Value("i", 0, lock=False) + + def __getstate__(self): + return cast(tuple[Any, ...], super().__getstate__()) + ( + self._resize_lock, + self._max_size_val, + self._pending_shrink, + ) + + def __setstate__(self, state): + super().__setstate__(state[:-3]) # pytype: disable=attribute-error + self._resize_lock, self._max_size_val, self._pending_shrink = state[-3:] + + def set_max_size(self, max_size: int): + """Sets the maximum size of the queue. + + This method can be used to dynamically change the capacity of the queue. + If `max_size` is greater than the current size, the queue capacity is + increased immediately. If `max_size` is smaller, the queue will prevent + new items from being added until the number of items in the queue drops + below the new `max_size`. + + Args: + max_size: The new maximum size for the queue. + """ + with self._resize_lock: + delta = max_size - self._max_size_val.value + if delta > 0: + for _ in range(delta): + self._sem.release() + self._max_size_val.value = max_size + elif delta < 0: + self._max_size_val.value = max_size + self._pending_shrink.value -= delta + # Try to shrink capacity eagerly, but don't block. + for _ in range(-delta): + if self._sem.acquire(block=False): + self._pending_shrink.value -= 1 + else: + break + + def get(self, block: bool = True, timeout: float | None = None): + """Gets an item from the queue, similar to `queue.Queue.get`.""" + if self._closed: # pytype: disable=attribute-error + raise ValueError(f"Queue {self!r} is closed") + if block and timeout is None: + with self._rlock: # pytype: disable=attribute-error + res = self._recv_bytes() # pytype: disable=attribute-error + else: + if block: + deadline = time.time() + timeout + if not self._rlock.acquire(block, timeout): # pytype: disable=attribute-error + raise queue.Empty + try: + if block: + timeout = deadline - time.time() + if not self._poll(timeout): # pytype: disable=attribute-error + raise queue.Empty + elif not self._poll(): # pytype: disable=attribute-error + raise queue.Empty + res = self._recv_bytes() # pytype: disable=attribute-error + finally: + self._rlock.release() # pytype: disable=attribute-error + with self._resize_lock: + if self._pending_shrink.value > 0: + self._pending_shrink.value -= 1 + else: + self._sem.release() + return reduction.ForkingPickler.loads(res) + + +class VariableSizeQueue(queue.Queue): + """A queue whose max size can be dynamically changed.""" + + def __init__(self, max_size: int): + super().__init__(maxsize=0) + self._max_size = max_size + self._cond = threading.Condition() + + def set_max_size(self, max_size: int): + with self._cond: + self._max_size = max_size + self._cond.notify_all() + + def put(self, item, block: bool = True, timeout: float | None = None): + """Puts an item into the queue, similar to `queue.Queue.put`. 
+ + This method behaves like `queue.Queue.put`, but respects the current + `_max_size` of this variable-size queue. If the queue is full based on + `_max_size`, this method can block or raise `queue.Full` depending on + `block` and `timeout`. + + Args: + item: The object to put into the queue. + block: If True, block until a free slot is available. + timeout: If `block` is True, wait for at most `timeout` seconds. + + Raises: + queue.Full: If the queue is full and `block` is False or the `timeout` + is reached. + """ + if not block: + with self._cond: + if self.qsize() >= self._max_size: + raise queue.Full + super().put(item, block=False) + return + + deadline = None + if timeout is not None: + deadline = time.time() + timeout + + with self._cond: + while self.qsize() >= self._max_size: + if deadline is None: + self._cond.wait() + continue + remaining = deadline - time.time() + if remaining <= 0: + raise queue.Full + if not self._cond.wait(remaining): + if self.qsize() >= self._max_size: + raise queue.Full + else: + break + super().put(item, block=False) + + def get(self, block: bool = True, timeout: float | None = None): + item = super().get(block=block, timeout=timeout) + with self._cond: + self._cond.notify() + return item diff --git a/grain/_src/python/variable_size_queue_test.py b/grain/_src/python/variable_size_queue_test.py new file mode 100644 index 000000000..fa4e0767d --- /dev/null +++ b/grain/_src/python/variable_size_queue_test.py @@ -0,0 +1,310 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for variable size queue implementations.""" + +import queue +import threading +import time + +from absl.testing import absltest +import multiprocessing as mp +from grain._src.python import variable_size_queue + + +def _consumer_function_for_test(q, result): + time.sleep(0.1) + result.append(q.get()) + + +def _increase_max_size_function_for_test(q): + time.sleep(0.1) + q.set_max_size(2) + + +class VariableSizeQueueTest(absltest.TestCase): + + def test_put_and_get(self): + q = variable_size_queue.VariableSizeQueue(max_size=1) + self.assertEqual(q.qsize(), 0) + q.put(1) + self.assertEqual(q.qsize(), 1) + self.assertEqual(q.get(), 1) + self.assertEqual(q.qsize(), 0) + + def test_put_non_blocking_to_full_queue_raises_full(self): + q = variable_size_queue.VariableSizeQueue(max_size=1) + q.put(1) + with self.assertRaises(queue.Full): + q.put(2, block=False) + + def test_put_blocking_with_timeout_to_full_queue_raises_full(self): + q = variable_size_queue.VariableSizeQueue(max_size=1) + q.put(1) + with self.assertRaises(queue.Full): + q.put(2, block=True, timeout=0.1) + + def test_set_max_size_to_increase_capacity(self): + q = variable_size_queue.VariableSizeQueue(max_size=1) + q.put(1) + with self.assertRaises(queue.Full): + q.put(2, block=False) + q.set_max_size(2) + q.put(2) # Should not raise. 
+ self.assertEqual(q.qsize(), 2) + self.assertEqual(q.get(), 1) + self.assertEqual(q.get(), 2) + + def test_set_max_size_to_decrease_capacity(self): + q = variable_size_queue.VariableSizeQueue(max_size=2) + q.put(1) + q.put(2) + self.assertEqual(q.qsize(), 2) + q.set_max_size(1) + # qsize is 2, max_size is 1. put should fail. + with self.assertRaises(queue.Full): + q.put(3, block=False) + self.assertEqual(q.get(), 1) + self.assertEqual(q.qsize(), 1) + # qsize is 1, max_size is 1. put should fail. + with self.assertRaises(queue.Full): + q.put(3, block=False) + self.assertEqual(q.get(), 2) + self.assertEqual(q.qsize(), 0) + # qsize is 0, max_size is 1. put should succeed. + q.put(3) + self.assertEqual(q.qsize(), 1) + self.assertEqual(q.get(), 3) + + def test_put_blocks_until_item_is_retrieved(self): + q = variable_size_queue.VariableSizeQueue(max_size=1) + q.put(1) + result = [] + + def consumer(): + time.sleep(0.1) + result.append(q.get()) + + t = threading.Thread(target=consumer) + t.start() + q.put(2) # This should block until consumer gets item 1. + self.assertEqual(q.qsize(), 1) + self.assertEqual(q.get(), 2) + t.join() + self.assertEqual(result, [1]) + + def test_put_blocks_until_max_size_increases(self): + q = variable_size_queue.VariableSizeQueue(max_size=1) + q.put(1) + + def increase_max_size(): + time.sleep(0.1) + q.set_max_size(2) + + t = threading.Thread(target=increase_max_size) + t.start() + q.put(2) # This should block until max_size is increased. + self.assertEqual(q.qsize(), 2) + self.assertEqual(q.get(), 1) + self.assertEqual(q.get(), 2) + t.join() + + def test_set_max_size_to_decrease_capacity_blocks_put(self): + q = variable_size_queue.VariableSizeQueue(max_size=2) + q.put(1) + q.put(2) + q.set_max_size(1) + + put_event = threading.Event() + + def _blocking_put(): + q.put(3) + put_event.set() + + t = threading.Thread(target=_blocking_put) + t.start() + + # The queue size is 2, max_size is 1. The put(3) call should block. + # We wait a bit to ensure the thread has started and blocked on put(). + time.sleep(0.2) + self.assertFalse(put_event.is_set()) + + # Get one item. qsize becomes 1, which equals max_size. + # However, because of _pending_shrink, no capacity is released, + # so put(3) should still be blocked. + self.assertEqual(q.get(), 1) + time.sleep(0.2) + self.assertFalse(put_event.is_set()) + + # Get another item. qsize becomes 0. + # This time, capacity should be released, unblocking put(3). 
+ self.assertEqual(q.get(), 2) + self.assertTrue(put_event.wait(timeout=1)) + self.assertEqual(q.get(), 3) + t.join() + + +class VariableSizeMultiprocessingQueueTest(absltest.TestCase): + + def test_put_and_get(self): + q = variable_size_queue.VariableSizeMultiprocessingQueue( + 1, ctx=mp.get_context("spawn") + ) + q.put(1) + self.assertEqual(q.get(), 1) + + def test_put_non_blocking_to_full_queue_raises_full(self): + q = variable_size_queue.VariableSizeMultiprocessingQueue( + 1, ctx=mp.get_context("spawn") + ) + q.put(1) + with self.assertRaises(queue.Full): + q.put(2, block=False) + + def test_put_blocking_with_timeout_to_full_queue_raises_full(self): + q = variable_size_queue.VariableSizeMultiprocessingQueue( + 1, ctx=mp.get_context("spawn") + ) + q.put(1) + with self.assertRaises(queue.Full): + q.put(2, block=True, timeout=0.1) + + def test_set_max_size_to_increase_capacity(self): + q = variable_size_queue.VariableSizeMultiprocessingQueue( + 1, ctx=mp.get_context("spawn") + ) + q.put(1) + with self.assertRaises(queue.Full): + q.put(2, block=False) + q.set_max_size(2) + q.put(2, block=False) # Should not raise. + self.assertEqual(q.get(), 1) + self.assertEqual(q.get(), 2) + + def test_set_max_size_to_decrease_capacity(self): + q = variable_size_queue.VariableSizeMultiprocessingQueue( + 2, ctx=mp.get_context("spawn") + ) + q.put(1) + q.put(2) + q.set_max_size(1) + # qsize is 2, max_size is 1. put should fail. + with self.assertRaises(queue.Full): + q.put(3, block=False) + self.assertEqual(q.get(), 1) + # qsize is 1, max_size is 1. put should fail. + with self.assertRaises(queue.Full): + q.put(3, block=False) + self.assertEqual(q.get(), 2) + # qsize is 0, max_size is 1. put should succeed. + q.put(3, block=False) + self.assertEqual(q.get(), 3) + + def test_set_max_size_to_decrease_capacity_blocks_put(self): + ctx = mp.get_context("spawn") + q = variable_size_queue.VariableSizeMultiprocessingQueue(2, ctx=ctx) + q.put(1) + q.put(2) + q.set_max_size(1) + + put_event = threading.Event() + + def _blocking_put(): + q.put(3) + put_event.set() + + t = threading.Thread(target=_blocking_put) + t.start() + + # The queue has 2 items, max_size is 1. The put(3) call should block. + # We wait a bit to ensure the thread has started and blocked on put(). + time.sleep(0.2) + self.assertFalse(put_event.is_set()) + + # Get one item. conceptual size becomes 1, which equals max_size. + # However, because of _pending_shrink, no capacity is released, + # so put(3) should still be blocked. + self.assertEqual(q.get(), 1) + time.sleep(0.2) + self.assertFalse(put_event.is_set()) + + # Get another item. conceptual size becomes 0. + # This time, capacity should be released, unblocking put(3). + self.assertEqual(q.get(), 2) + self.assertTrue(put_event.wait(timeout=1)) + self.assertEqual(q.get(), 3) + t.join() + + def test_put_blocks_until_item_is_retrieved_from_process(self): + ctx = mp.get_context("spawn") + q = variable_size_queue.VariableSizeMultiprocessingQueue(1, ctx=ctx) + q.put(1) + + with ctx.Manager() as manager: + result_list = manager.list() + + p = ctx.Process(target=_consumer_function_for_test, args=(q, result_list)) + p.start() + q.put(2) # This should block until consumer gets item 1. 
+ self.assertEqual(q.get(), 2) + p.join() + self.assertEqual(list(result_list), [1]) + + def test_put_blocks_until_max_size_increases_from_process(self): + ctx = mp.get_context("spawn") + q = variable_size_queue.VariableSizeMultiprocessingQueue(1, ctx=ctx) + q.put(1) + + p = ctx.Process(target=_increase_max_size_function_for_test, args=(q,)) + p.start() + # This should block until max_size is increased in the other process. + q.put(2) + self.assertEqual(q.get(), 1) + self.assertEqual(q.get(), 2) + p.join() + + def test_empty(self): + q = variable_size_queue.VariableSizeMultiprocessingQueue( + 1, ctx=mp.get_context("spawn") + ) + self.assertTrue(q.empty()) + q.put(1, block=False) + while q.empty(): + time.sleep(0.1) + self.assertFalse(q.empty()) + q.get() + self.assertTrue(q.empty()) + + def test_get_nowait(self): + q = variable_size_queue.VariableSizeMultiprocessingQueue( + 1, ctx=mp.get_context("spawn") + ) + with self.assertRaises(queue.Empty): + q.get_nowait() + q.put(1) + while q.empty(): + time.sleep(0.1) + self.assertEqual(q.get_nowait(), 1) + with self.assertRaises(queue.Empty): + q.get_nowait() + + def test_close_and_cancel_join_thread(self): + q = variable_size_queue.VariableSizeMultiprocessingQueue( + 1, ctx=mp.get_context("spawn") + ) + q.close() + q.cancel_join_thread() + + +if __name__ == "__main__": + absltest.main()
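
Usage sketch (not part of the patch): the new set_per_worker_buffer_size API on the multiprocess prefetch iterator, mirroring test_set_per_worker_buffer_size_increase above. Import paths are inferred from the file locations in this diff and may differ; the dataset contents and buffer sizes are illustrative.

from typing import cast

from grain._src.python import options  # assumed module path
from grain._src.python.dataset import dataset  # assumed module path
from grain._src.python.dataset.transformations import prefetch

ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset()
mp_options = options.MultiprocessingOptions(
    num_workers=1, per_worker_buffer_size=1
)
it = prefetch.MultiprocessPrefetchIterDataset(ds, mp_options).__iter__()
it = cast(prefetch._MultiprocessPrefetchDatasetIterator, it)

# The buffer size can only be changed once the underlying raw iterator exists;
# calling it earlier raises ValueError per the check added in prefetch.py.
next(it)

# Grow the per-worker output queues, the reader queue, and the reader thread
# pool while iterating.
it.set_per_worker_buffer_size(4)

# Shrinking is also allowed: workers stop enqueueing new elements until the
# number of buffered elements drops below the new limit.
it.set_per_worker_buffer_size(1)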
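
Standalone sketch (not part of the patch) of the VariableSizeQueue semantics, mirroring variable_size_queue_test.py: growing the max size takes effect immediately, while shrinking only throttles new puts until the queue drains below the new limit. VariableSizeMultiprocessingQueue behaves the same way, but is constructed with a multiprocessing context and can be shared across processes.

import queue

from grain._src.python import variable_size_queue

q = variable_size_queue.VariableSizeQueue(max_size=1)
q.put(1)
try:
  q.put(2, block=False)  # Full at the current max size.
except queue.Full:
  pass

q.set_max_size(2)  # Capacity grows immediately.
q.put(2)

q.set_max_size(1)  # Shrinks lazily: already-queued items stay in the queue,
                   # but new puts block (or raise queue.Full) until qsize < 1.
assert q.get() == 1
assert q.get() == 2
q.put(3)  # qsize is now 0, below the new max size, so this succeeds.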