Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions src/lightning/fabric/strategies/launchers/multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,29 @@ def _check_bad_cuda_fork() -> None:
Lightning users.

"""
if not torch.cuda.is_initialized():
return

message = (
"Lightning can't create new processes if CUDA is already initialized. Did you manually call"
" `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
" other way? Please remove any such calls, or change the selected strategy."
)
if _IS_INTERACTIVE:
message += " You will have to restart the Python kernel."
raise RuntimeError(message)
# Use PyTorch's internal check for bad fork state, which is more accurate than just checking if CUDA
# is initialized. This allows passive CUDA initialization (e.g., from library imports or device queries)
# while still catching actual problematic cases where CUDA context was created before forking.
_is_in_bad_fork = getattr(torch.cuda, "_is_in_bad_fork", None)
if _is_in_bad_fork is not None and callable(_is_in_bad_fork) and _is_in_bad_fork():
message = (
"Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, "
"you must use the 'spawn' start method or avoid CUDA initialization in the main process."
)
if _IS_INTERACTIVE:
message += " You will have to restart the Python kernel."
raise RuntimeError(message)

Copy link

Copilot AI Dec 3, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trailing whitespace found at the end of the line. Remove the extra whitespace.

Copilot uses AI. Check for mistakes.
# Fallback to the old check if _is_in_bad_fork is not available (older PyTorch versions)
if _is_in_bad_fork is None and torch.cuda.is_initialized():
message = (
"Lightning can't create new processes if CUDA is already initialized. Did you manually call"
" `torch.cuda.*` functions, have moved the model to the device, or allocated memory on the GPU any"
" other way? Please remove any such calls, or change the selected strategy."
)
if _IS_INTERACTIVE:
message += " You will have to restart the Python kernel."
raise RuntimeError(message)


def _disable_module_memory_sharing(data: Any) -> Any:
Expand Down
11 changes: 11 additions & 0 deletions tests/tests_fabric/strategies/launchers/test_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ def test_check_for_bad_cuda_fork(mp_mock, _, start_method):
launcher.launch(function=Mock())


@pytest.mark.parametrize("start_method", ["fork", "forkserver"])
@mock.patch("torch.cuda._is_in_bad_fork", return_value=True)
@mock.patch("lightning.fabric.strategies.launchers.multiprocessing.mp")
def test_check_for_bad_cuda_fork_with_is_in_bad_fork(mp_mock, _, start_method):
    """Launching must raise once PyTorch reports a bad-fork CUDA state via ``_is_in_bad_fork``."""
    mp_mock.get_all_start_methods.return_value = [start_method]
    expected_error = "Cannot re-initialize CUDA in forked subprocess"
    launcher = _MultiProcessingLauncher(strategy=Mock(), start_method=start_method)
    with pytest.raises(RuntimeError, match=expected_error):
        launcher.launch(function=Mock())


def test_check_for_missing_main_guard():
launcher = _MultiProcessingLauncher(strategy=Mock(), start_method="spawn")
with (
Expand Down