Summary
Reader(gpu='cuda:N') (or any non-default CUDA device) crashes whenever the host process is not using cuda:0. The recognizer is unconditionally wrapped in torch.nn.DataParallel(model) without device_ids, so DataParallel falls back to all visible GPUs and pins the primary to cuda:0. On a multi-rank training job (DDP/FSDP), every rank with local_rank != 0 errors out.
This is the same root cause as #295 (closed in 2022 without a fix). It is still present on current master.
Reproduction
import torch
import easyocr
# Pretend we're rank 1 in a multi-rank job — current device is cuda:1.
torch.cuda.set_device(1)
reader = easyocr.Reader(['en'], gpu='cuda:1') # also fails with gpu=True
img = '<any image>'
reader.readtext(img)
Crash:
RuntimeError: module must have its parameters and buffers on device cuda:0
(device_ids[0]) but found one of them on device: cuda:1
We hit this on a 4×GB200 node running an FSDP fine-tune; ranks 1/2/3 all errored on every frame.
Root cause
In easyocr/recognition.py:
model = torch.nn.DataParallel(model).to(device)
(also easyocr/detection.py per #295)
Without device_ids=[idx], DataParallel picks range(torch.cuda.device_count()) and treats device_ids[0] (cuda:0) as authoritative. The subsequent .to(device) doesn't override DataParallel's device_ids.
Reader.__init__ already supports gpu='cuda:N' (it goes straight into self.device), but the DP wrap nullifies it.
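The pinning can be demonstrated without EasyOCR at all. A minimal sketch (assumes a machine with at least two visible GPUs):
import torch
torch.cuda.set_device(1)                         # pretend we're rank 1
model = torch.nn.Linear(4, 4).to('cuda:1')       # parameters live on cuda:1
dp = torch.nn.DataParallel(model).to('cuda:1')   # no device_ids, so device_ids[0] stays cuda:0
dp(torch.randn(2, 4, device='cuda:1'))           # raises the same RuntimeError quoted above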
Suggested fix
Either:
- Drop the DataParallel wrap entirely (the model is already moved to the right device by .to(device); DP adds nothing for inference and conflicts with parent-process distributed training).
- Honor the requested device when DP is kept:
device_ids = [torch.device(device).index] if 'cuda' in str(device) else None
model = torch.nn.DataParallel(model, device_ids=device_ids).to(device)
Option 1 is preferable: PyTorch itself recommends DistributedDataParallel over DataParallel, and single-process inference gains nothing from the wrap.
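Whichever option lands, a quick sanity check on a box with 2+ GPUs could look like this (a sketch; it assumes the recognizer is exposed as reader.recognizer, as on current master):
import torch, easyocr
torch.cuda.set_device(1)
reader = easyocr.Reader(['en'], gpu='cuda:1')
# With the fix, every recognizer parameter should sit on the requested device.
assert all(p.device == torch.device('cuda:1') for p in reader.recognizer.parameters())
reader.readtext('<any image>')  # should no longer raise the device-mismatch error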
Workaround (current)
Stub nn.DataParallel to a passthrough during Reader init so the model loads straight onto the rank-local device:
import torch, easyocr
orig_dp = torch.nn.DataParallel
torch.nn.DataParallel = lambda m, *a, **kw: m
try:
    reader = easyocr.Reader(['en'], gpu=f'cuda:{torch.cuda.current_device()}')
finally:
    torch.nn.DataParallel = orig_dp
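If several call sites need the stub (e.g., one Reader per rank), the same idea reads more cleanly as a context manager. This is purely an illustrative helper, not part of EasyOCR:
import contextlib
import torch
import easyocr

@contextlib.contextmanager
def no_data_parallel():
    # Temporarily make nn.DataParallel a passthrough so models stay on the requested device.
    orig = torch.nn.DataParallel
    torch.nn.DataParallel = lambda module, *args, **kwargs: module
    try:
        yield
    finally:
        torch.nn.DataParallel = orig

with no_data_parallel():
    reader = easyocr.Reader(['en'], gpu=f'cuda:{torch.cuda.current_device()}')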
Happy to send a PR if a maintainer confirms which option is preferred.