18 changes: 18 additions & 0 deletions grain/_src/python/BUILD
@@ -224,6 +224,7 @@ py_library(
":options",
":record",
":shared_memory_array",
":variable_size_queue",
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"//grain/_src/core:parallel",
@@ -245,6 +246,7 @@ py_test(
":grain_pool",
":options",
":record",
":variable_size_queue",
"//grain/_src/core:config",
"//grain/_src/core:monitoring",
"@abseil-py//absl/flags",
@@ -370,3 +372,19 @@ py_library(
"@pypi//etils:pkg",
],
)

py_library(
name = "variable_size_queue",
srcs = ["variable_size_queue.py"],
srcs_version = "PY3",
)

py_test(
name = "variable_size_queue_test",
srcs = ["variable_size_queue_test.py"],
srcs_version = "PY3",
deps = [
":variable_size_queue",
"@abseil-py//absl/testing:absltest",
],
)
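Note: the new variable_size_queue targets reference variable_size_queue.py, which is not included in this diff. Judging from how the rest of the change uses it (VariableSizeQueue(max_size), set_max_size(), qsize(), plus a VariableSizeMultiprocessingQueue(max_size, ctx) counterpart for the worker output queues), a minimal sketch of the thread-side class could look like the following; treat it as an illustration of the assumed interface, not the actual implementation.

import collections
import threading


class VariableSizeQueue:
  """Blocking FIFO queue whose maximum size can be changed after creation."""

  def __init__(self, max_size: int):
    self._max_size = max_size
    self._items = collections.deque()
    self._cond = threading.Condition()

  def set_max_size(self, max_size: int) -> None:
    # Growing the limit wakes producers blocked on a full queue; shrinking it
    # only affects future put() calls.
    with self._cond:
      self._max_size = max_size
      self._cond.notify_all()

  def put(self, item) -> None:
    with self._cond:
      while len(self._items) >= self._max_size:
        self._cond.wait()
      self._items.append(item)
      self._cond.notify_all()

  def get(self):
    with self._cond:
      while not self._items:
        self._cond.wait()
      item = self._items.popleft()
      self._cond.notify_all()
      return item

  def qsize(self) -> int:
    with self._cond:
      return len(self._items)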
8 changes: 8 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch.py
@@ -760,6 +760,14 @@ def __str__(self) -> str:
f"multiprocessing_options={self._multiprocessing_options})"
)

def set_per_worker_buffer_size(self, per_worker_buffer_size: int):
if self._raw_iterator is None:
raise ValueError(
"Cannot change per worker buffer size before the iterator has been"
" initialized."
)
self._raw_iterator.set_per_worker_buffer_size(per_worker_buffer_size)


class ThreadPrefetchIterDataset(dataset.IterDataset[T]):
"""Iterable dataset that uses a synchronized queue for prefetching.
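For reference, hypothetical usage of the new method, mirroring the tests added below (values are illustrative):

ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset()
ds = prefetch.MultiprocessPrefetchIterDataset(
    ds,
    options.MultiprocessingOptions(num_workers=1, per_worker_buffer_size=1),
)
it = ds.__iter__()
next(it)  # Starts the worker processes, so the raw iterator exists.
it.set_per_worker_buffer_size(4)  # Workers may now buffer 4 elements each.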
71 changes: 71 additions & 0 deletions grain/_src/python/dataset/transformations/prefetch_test.py
@@ -917,6 +917,77 @@ def map_fn(x):
],
)

def test_set_per_worker_buffer_size_increase(self):
ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset()
mp_options = options.MultiprocessingOptions(
num_workers=1, per_worker_buffer_size=1
)
ds = prefetch.MultiprocessPrefetchIterDataset(
ds,
mp_options,
)
it = cast(prefetch._MultiprocessPrefetchDatasetIterator, ds.__iter__())
self.assertEqual(next(it), 1)
time.sleep(1)
self.assertEqual(
it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 1 # pytype: disable=attribute-error
)
it.set_per_worker_buffer_size(2)
self.assertEqual(
it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 2 # pytype: disable=attribute-error
)
self.assertEqual(next(it), 2)
self.assertEqual(list(it), list(range(3, 11)))

def test_set_per_worker_buffer_size_decrease(self):
ds = dataset.MapDataset.range(10).map(lambda x: x + 1).to_iter_dataset()
mp_options = options.MultiprocessingOptions(
num_workers=1, per_worker_buffer_size=2
)
ds = prefetch.MultiprocessPrefetchIterDataset(
ds,
mp_options,
)
it = cast(prefetch._MultiprocessPrefetchDatasetIterator, ds.__iter__())
self.assertEqual(next(it), 1)
time.sleep(1)
self.assertEqual(
it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 2 # pytype: disable=attribute-error
)
it.set_per_worker_buffer_size(1)
self.assertEqual(
it._raw_iterator._multiprocessing_options.per_worker_buffer_size, 1 # pytype: disable=attribute-error
)
self.assertEqual(next(it), 2)
self.assertEqual(list(it), list(range(3, 11)))

def test_set_per_worker_buffer_size_to_trigger_error(self):
def f(x):
if x >= 5:
raise ValueError(f'x={x} is too large')
return x

ds = (
dataset.MapDataset.range(10)
.map(f)
.to_iter_dataset(
read_options=options.ReadOptions(prefetch_buffer_size=0)
)
)
mp_options = options.MultiprocessingOptions(
num_workers=1, per_worker_buffer_size=1
)
it = prefetch.MultiprocessPrefetchIterDataset(ds, mp_options).__iter__()
it = cast(prefetch._MultiprocessPrefetchDatasetIterator, it)
self.assertEqual(next(it), 0)
it.set_per_worker_buffer_size(10)
next(it)
time.sleep(3)
q = it._raw_iterator._reader_queue # pytype: disable=attribute-error
# Prefetching will end once an error is put into the reader queue. The
# elements 2, 3, 4 will be in the queue along with the error for 5.
self.assertEqual(q.qsize(), 4)


class ThreadPrefetchIterDatasetTest(parameterized.TestCase):

82 changes: 73 additions & 9 deletions grain/_src/python/grain_pool.py
@@ -69,6 +69,7 @@
from grain._src.python import multiprocessing_common
from grain._src.python import record
from grain._src.python import shared_memory_array
from grain._src.python import variable_size_queue
from grain._src.python.options import MultiprocessingOptions # pylint: disable=g-importing-member


@@ -225,7 +226,7 @@ def _worker_loop(
*,
args_queue: queues.Queue,
errors_queue: queues.Queue,
output_queue: queues.Queue,
output_queue: variable_size_queue.VariableSizeMultiprocessingQueue,
termination_event: synchronize.Event,
start_profiling_event: synchronize.Event,
stop_profiling_event: synchronize.Event,
@@ -342,6 +343,9 @@ def __init__(
options: MultiprocessingOptions,
worker_init_fn: Callable[[int, int], None] | None = None,
stats_in_queues: tuple[queues.Queue, ...] | None = None,
worker_output_queues: list[
variable_size_queue.VariableSizeMultiprocessingQueue
],
):
"""Initialise a Grain Pool.

@@ -362,11 +366,13 @@
the total worker count.
stats_in_queues: Queue to propagate execution summary from child processes
to the parent.
worker_output_queues: list of output queues, one per worker, into which
each worker writes the elements it produces.
"""
self.num_processes = options.num_workers
logging.info("Grain pool will use %i processes.", self.num_processes)
self.worker_args_queues = []
self.worker_output_queues = []
self.worker_output_queues = worker_output_queues
self.processes = []
# Reader termination should always result in worker termination. However,
# worker termination should not shut down the reader: workers are terminated
@@ -396,11 +402,10 @@

for worker_index in range(self.num_processes):
worker_args_queue = ctx.Queue(1)
worker_output_queue = ctx.Queue(options.per_worker_buffer_size)
process_kwargs = dict(
args_queue=worker_args_queue,
errors_queue=self.worker_error_queue,
output_queue=worker_output_queue,
output_queue=self.worker_output_queues[worker_index],
stats_out_queue=(
self.stats_in_queues[worker_index]
if self.stats_in_queues
@@ -434,7 +439,6 @@ def __init__(
target=_worker_loop, kwargs=process_kwargs, daemon=True
)
self.worker_args_queues.append(worker_args_queue)
self.worker_output_queues.append(worker_output_queue)
self.processes.append(process)

logging.info("Grain pool will start child processes.")
@@ -589,6 +593,27 @@ class _GrainPoolProcessingComplete:
]


class _ThreadPoolContainer:
"""Container for ThreadPool to allow replacing it."""

def __init__(self, processes: int):
self.pool = pool.ThreadPool(processes)

def apply_async(self, *args, **kwargs):
return self.pool.apply_async(*args, **kwargs)

def close(self):
self.pool.close()

def join(self):
self.pool.join()

def replace_pool(self, num_threads: int):
old_pool = self.pool
self.pool = pool.ThreadPool(num_threads)
old_pool.close()


def _open_shared_memory_for_leaf(element: Any) -> Any:
if isinstance(element, shared_memory_array.SharedMemoryArrayMetadata):
element = shared_memory_array.SharedMemoryArray.from_metadata(element)
@@ -610,6 +635,9 @@ def _process_elements_in_grain_pool(
get_element_producer_fn: GetElementProducerFn,
multiprocessing_options: MultiprocessingOptions,
reader_queue: queue.Queue[_QueueElement],
worker_output_queues: list[
variable_size_queue.VariableSizeMultiprocessingQueue
],
thread_pool: pool.ThreadPool,
termination_event: threading.Event,
start_profiling_event: synchronize.Event | None,
@@ -636,6 +664,7 @@ def read_thread_should_stop():
options=multiprocessing_options,
worker_init_fn=worker_init_fn,
stats_in_queues=stats_in_queues,
worker_output_queues=worker_output_queues,
) as g_pool:
for element in g_pool:
if read_thread_should_stop():
@@ -714,6 +743,7 @@ def __init__(
self._last_worker_index = worker_index_to_start_reading - 1
self._worker_init_fn = worker_init_fn
self._reader_queue = None
self._worker_output_queues = None
self._reader_thread_pool = None
self._termination_event = None
self._reader_thread = None
@@ -736,15 +766,26 @@ def start_prefetch(self) -> None:
self._multiprocessing_options.num_workers
* self._multiprocessing_options.per_worker_buffer_size
)
self._reader_queue = queue.Queue(maxsize=max_buffered_elements)
self._reader_thread_pool = pool.ThreadPool(max_buffered_elements)
self._reader_queue = variable_size_queue.VariableSizeQueue(
max_buffered_elements
)
self._reader_thread_pool = _ThreadPoolContainer(max_buffered_elements)
self._termination_event = threading.Event()
ctx = mp.get_context("spawn")
self._worker_output_queues = []
for _ in range(self._multiprocessing_options.num_workers):
self._worker_output_queues.append(
variable_size_queue.VariableSizeMultiprocessingQueue(
self._multiprocessing_options.per_worker_buffer_size, ctx
)
)
self._reader_thread = threading.Thread(
target=_process_elements_in_grain_pool,
kwargs=dict(
get_element_producer_fn=self._get_element_producer_fn,
multiprocessing_options=self._multiprocessing_options,
reader_queue=self._reader_queue,
worker_output_queues=self._worker_output_queues,
thread_pool=self._reader_thread_pool,
termination_event=self._termination_event,
start_profiling_event=self._start_profiling_event,
@@ -775,6 +816,7 @@ def stop_prefetch(self) -> None:
self._reader_thread_pool = None
self._reader_thread = None
self._reader_queue = None
self._worker_output_queues = None

def __enter__(self):
self.start_prefetch()
@@ -809,7 +851,7 @@ def __next__(self):
"MultiProcessIterator is in an invalid state. Note that"
" MultiProcessIterator should be used with a 'with' statement."
)
element = multiprocessing_common.get_element_from_queue(
element = multiprocessing_common.get_element_from_queue( # pytype: disable=wrong-arg-types
self._reader_queue, self._termination_event.is_set # pytype: disable=attribute-error
)
if isinstance(element, Exception):
@@ -826,9 +868,31 @@
)

result = multiprocessing_common.get_async_result(
element.async_result, self._termination_event.is_set
element.async_result, self._termination_event.is_set # pytype: disable=attribute-error
)
if isinstance(result, multiprocessing_common._SystemTerminated): # pylint: disable=protected-access
raise StopIteration
self._last_worker_index = element.worker_index
return result

def set_per_worker_buffer_size(self, per_worker_buffer_size: int):
"""Sets the per worker buffer size."""
if self._worker_output_queues is None or self._reader_queue is None:
raise ValueError(
"Cannot change per worker buffer size before the iterator has been"
" initialized."
)
for q in self._worker_output_queues:
q.set_max_size(per_worker_buffer_size)
self._reader_queue.set_max_size(
per_worker_buffer_size * self._multiprocessing_options.num_workers
)
self._multiprocessing_options = dataclasses.replace(
self._multiprocessing_options,
per_worker_buffer_size=per_worker_buffer_size,
)
new_thread_count = (
self._multiprocessing_options.num_workers
* self._multiprocessing_options.per_worker_buffer_size
)
self._reader_thread_pool.replace_pool(new_thread_count) # pytype: disable=attribute-error
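A hypothetical illustration of the semantics assumed here (values are illustrative): _ThreadPoolContainer.replace_pool directs future apply_async calls to a fresh pool while close() on the old pool lets already-submitted tasks finish, and set_per_worker_buffer_size fans the new size out to the worker output queues, the reader queue, and the reader thread pool.

# replace_pool semantics:
container = _ThreadPoolContainer(processes=2)
pending = container.apply_async(lambda: 42)   # runs on the original pool
container.replace_pool(num_threads=8)         # old pool closed, not joined
later = container.apply_async(lambda: 43)     # runs on the new pool
assert pending.get() == 42 and later.get() == 43

# Buffer-size propagation: with num_workers=4, it.set_per_worker_buffer_size(8)
# calls set_max_size(8) on each of the 4 worker output queues, resizes the
# reader queue to 8 * 4 == 32, and replaces the reader thread pool with one
# of 32 threads.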