39 changes: 24 additions & 15 deletions fast_llm/core/distributed.py
@@ -15,7 +15,6 @@

import torch
import torch.monitor
from torch._C._distributed_c10d import Work
from torch.distributed import ( # noqa
ProcessGroup,
ReduceOp,
@@ -29,6 +28,15 @@
logger = logging.getLogger(__name__)


def _get_device(group: ProcessGroup) -> torch.device:
if torch.distributed.is_nccl_available() and isinstance(group, torch.distributed.ProcessGroupNCCL):
return torch.device(torch.cuda.current_device())
elif isinstance(group, torch.distributed.ProcessGroupGloo):
return torch.device("cpu")
else:
raise NotImplementedError(type(group))


@contextlib.contextmanager
def set_timeout(group: ProcessGroup | None, timeout: float | None = None):
if group is not None and timeout is not None:
@@ -42,7 +50,7 @@ def set_timeout(group: ProcessGroup | None, timeout: float | None = None):

def broadcast(
tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, timeout: float | None = None
) -> Work | None:
) -> torch.distributed.Work | None:
"""Same as torch.distributed.broadcast, but without the complication of going through the global rank."""
assert group is not None
opts = torch.distributed.BroadcastOptions()
@@ -72,12 +80,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name:
)


def safe_barrier(
group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None
) -> None:
def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None:
if group:
hashed = hash(value) % 2**32
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device)
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout)
if out != hashed * group.size():
raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})")

@@ -88,10 +94,9 @@ def allreduce_scalar(
group: torch.distributed.ProcessGroup | None = None,
op=ReduceOp.SUM,
timeout: float | None = None,
device: torch.device | None = None,
) -> float | int:
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device)
value = torch.full([1], value, dtype=dtype, device=_get_device(group))
with set_timeout(group, timeout):
torch.distributed.all_reduce(value, op=op, group=group)
return value.item()
@@ -106,7 +111,7 @@ def all_gather_scalar(
timeout: float | None = None,
):
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
value = torch.full([1], value, dtype=dtype, device=_get_device(group))
output_tensor = value.new_empty((group.size(),))
with set_timeout(group, timeout):
torch.distributed.all_gather_into_tensor(output_tensor, value, group=group)
@@ -116,15 +121,15 @@


def broadcast_scalar(
value: float | int,
value: float | int | None,
dtype: torch.dtype = torch.float64,
group: torch.distributed.ProcessGroup | None = None,
src: int = 0,
timeout: float | None = None,
) -> float | int:
if not group:
return value
tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device()))
tensor = torch.empty([1], dtype=dtype, device=torch.device(_get_device(group)))
if group.rank() == src:
tensor.fill_(value)
broadcast(tensor, src, group, timeout=timeout)
@@ -141,19 +146,21 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None
if group.rank() == src:
tensor = _object_to_tensor(input_object)
size = tensor.numel()
broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group))
broadcast_tensor.copy_(tensor)
broadcast_scalar(size, torch.int64, group, src)
broadcast(broadcast_tensor, src, group)
return input_object
else:
size = int(broadcast_scalar(None, torch.int64, group, src))
output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
output_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group))
broadcast(output_tensor, src, group)
return _tensor_to_object(output_tensor)


def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
def send(
tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0
) -> torch.distributed.Work | None:
assert group is not None
if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu":
# send not supported for gloo on GPU.
Expand All @@ -169,7 +176,9 @@ def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, ta
return None


def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None:
def recv(
tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0
) -> torch.distributed.Work | None:
assert group is not None
if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu":
# recv not supported for gloo on GPU.
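The new `_get_device` helper lets the scalar collectives (`allreduce_scalar`, `all_gather_scalar`, `broadcast_scalar`, `broadcast_object`) stage their temporary tensor on a device matching the process-group backend: the current CUDA device for NCCL, CPU for Gloo. Combined with dropping the `device` argument from `safe_barrier` and `allreduce_scalar`, the Gloo path no longer touches CUDA at all. A minimal sketch under assumptions (two local CPU ranks, a directly constructed `ProcessGroupGloo`, placeholder address and port; none of this is part of the diff):

```python
# Minimal sketch (not part of the diff): two local CPU ranks on a Gloo group,
# exercising the backend-aware device selection. The address/port and the
# direct ProcessGroupGloo construction are assumptions for illustration only.
import torch
import torch.distributed
import torch.multiprocessing as mp

from fast_llm.core.distributed import allreduce_scalar, safe_barrier


def _worker(rank: int, world_size: int) -> None:
    store = torch.distributed.TCPStore("127.0.0.1", 29511, world_size, is_master=rank == 0)
    group = torch.distributed.ProcessGroupGloo(store, rank, world_size)

    # `_get_device` resolves to CPU for a Gloo group, so no CUDA device is needed.
    total = allreduce_scalar(rank + 1, dtype=torch.int64, group=group)
    assert total == world_size * (world_size + 1) // 2

    # `safe_barrier` no longer takes a `device` argument; it reuses the same helper.
    safe_barrier(group, 1)


if __name__ == "__main__":
    mp.spawn(_worker, args=(2,), nprocs=2)
```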
7 changes: 3 additions & 4 deletions fast_llm/core/ops.py
@@ -8,7 +8,6 @@
import torch
import torch._dynamo # noqa
import torch.autograd
from torch._C._distributed_c10d import Work

from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_gather_into_tensor, all_reduce, reduce_scatter_tensor
from fast_llm.utils import Assert, div
@@ -18,7 +17,7 @@

def reduce_op(
input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp = ReduceOp.SUM, async_op: bool = False
) -> tuple[torch.Tensor, Work] | torch.Tensor:
) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor:
if group:
handle = all_reduce(input_, group=group, async_op=async_op, op=op)
else:
@@ -62,7 +61,7 @@ def swap_mult_dim(tensor: torch.Tensor, factor: int, old_dim: int, new_dim: int)

def gather_op(
input_: torch.Tensor, group: ProcessGroup | None, dim: int, async_op: bool = False, out=None
) -> tuple[torch.Tensor, Work] | torch.Tensor:
) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor:
"""Gather tensors and concatenate along the last dimension."""
# Bypass the function if we are using only 1 GPU.
if not group:
@@ -89,7 +88,7 @@ def reduce_scatter_op(
op: ReduceOp = ReduceOp.SUM,
dim: int = 0,
async_op: bool = False,
) -> tuple[torch.Tensor, Work] | torch.Tensor:
) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor:
"""Reduce-scatter the input tensor across model parallel group."""
# Bypass the function if we are using only 1 GPU.
if not group:
File renamed without changes.
2 changes: 1 addition & 1 deletion fast_llm/data/data/gpt/data.py
@@ -8,8 +8,8 @@

from fast_llm.core.distributed import safe_barrier
from fast_llm.data.data.abstract import Data
from fast_llm.data.data.data_loader import SampledDatasetIterator
from fast_llm.data.data.gpt.config import GPTDataConfig
from fast_llm.data.data_loader import SampledDatasetIterator
from fast_llm.data.dataset.abstract import SampledDataset
from fast_llm.data.dataset.config import SamplingParameters
from fast_llm.data.dataset.gpt.config import GPTSamplingData
22 changes: 16 additions & 6 deletions fast_llm/data/preprocessing/tokenizer.py
@@ -11,6 +11,7 @@
if typing.TYPE_CHECKING:
import numpy as np
import torch
import transformers


@config_class(dynamic_type={PreprocessingConfig: "tokenizer"})
@@ -52,7 +53,7 @@ def __init__(self, config: ConfigType):
from transformers import AutoTokenizer

log_main_rank(f"> loading tokenizer from {config.path} ...")
self.tokenizer = AutoTokenizer.from_pretrained(
self.tokenizer: "transformers.PreTrainedTokenizer" = AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=self._config.path,
errors="replace",
max_len=None,
@@ -70,10 +71,15 @@ def __init__(self, config: ConfigType):

@functools.cached_property
def vocab_size(self) -> int:
out = len(self.tokenizer)
if self._config.max_vocab_size is not None:
out = min(out, self._config.max_vocab_size)
return out
return (
self._tokenizer_vocab_size
if self._config.max_vocab_size is None
else min(self._tokenizer_vocab_size, self._config.max_vocab_size)
)

@functools.cached_property
def _tokenizer_vocab_size(self) -> int:
return len(self.tokenizer)

@property
def vocab(self) -> dict[str, int]:
@@ -99,7 +105,11 @@ def tokenize(
tokens = (
torch.tensor(
tokens,
dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch,
dtype=(
torch.int64
if self._tokenizer_vocab_size > torch.iinfo(data_type.torch).max
else data_type.torch
),
)
% self._config.max_vocab_size
).to(data_type.torch)
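In the tokenizer, the raw tokenizer size is now cached as `_tokenizer_vocab_size` and used both to cap the reported `vocab_size` at `max_vocab_size` and to decide whether token ids need an `int64` intermediate before being folded back into range with the modulo. A small illustration of that arithmetic (the numbers and the `uint16` storage type are made up):

```python
# Made-up numbers illustrating the capping/folding logic above.
import torch

tokenizer_vocab_size = 151_000   # stands in for len(self.tokenizer)
max_vocab_size = 65_536          # stands in for config.max_vocab_size
storage_dtype = torch.uint16     # stands in for data_type.torch (needs PyTorch >= 2.3)

# The reported vocab size is capped by the configured maximum.
vocab_size = min(tokenizer_vocab_size, max_vocab_size)  # 65536

# Ids can exceed the storage dtype's range, so they are built in int64 first,
# folded into [0, max_vocab_size) with a modulo, then cast down.
raw_tokens = [3, 70_000, 150_999]
dtype = torch.int64 if tokenizer_vocab_size > torch.iinfo(storage_dtype).max else storage_dtype
tokens = (torch.tensor(raw_tokens, dtype=dtype) % max_vocab_size).to(storage_dtype)
print(tokens)  # tensor([    3,  4464, 19927], dtype=torch.uint16)
```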
2 changes: 1 addition & 1 deletion fast_llm/data/sample/patch.py
@@ -85,7 +85,7 @@ def __len__(self) -> int:
return self.sample_size

def get_padding(self, size: int) -> typing.Self:
return PatchSample(
return self.__class__(
self.patches.new_empty((0, *self.patches.shape[1:])),
self.token_map.new_empty(0),
self.positions.new_empty([0, self.patches.ndim - 2]),
2 changes: 1 addition & 1 deletion fast_llm/data/sample/range.py
@@ -52,7 +52,7 @@ def __len__(self) -> int:
return self.sample_size

def get_padding(self, size: int) -> typing.Self:
return RangeSample([], size)
return self.__class__([], size)


class RangeBatch(Batch):
2 changes: 1 addition & 1 deletion fast_llm/data/sample/token.py
@@ -58,7 +58,7 @@ def __len__(self) -> int:
return len(self.tokens)

def get_padding(self, size: int) -> typing.Self:
return TokenSample(torch.full([size], -100, dtype=self.tokens.dtype), [size])
return self.__class__(torch.full([size], -100, dtype=self.tokens.dtype), [size])


class TokenBatch(Batch):
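The `get_padding` implementations in `patch.py`, `range.py`, and `token.py` now construct `self.__class__` instead of hard-coding the class name, so a subclass gets padding of its own type. A hypothetical sketch, assuming `TokenSample` can be subclassed without extra overrides and that its constructor takes the `(tokens, lengths)` pair used inside `get_padding`:

```python
# Hypothetical subclass (not in the repo), showing the effect of using
# self.__class__ in get_padding: the padding keeps the subclass type.
import torch

from fast_llm.data.sample.token import TokenSample


class MyTokenSample(TokenSample):
    """Illustration only; assumes TokenSample needs no extra overrides."""


sample = MyTokenSample(torch.tensor([1, 2, 3]), [3])
padding = sample.get_padding(8)
assert type(padding) is MyTokenSample   # previously this was a plain TokenSample
assert len(padding) == 8                # filled with the -100 padding value
```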
3 changes: 2 additions & 1 deletion fast_llm/engine/checkpoint/config.py
@@ -141,11 +141,12 @@ class CheckpointSaveConfigBase(CheckpointConfigBase):

@config_class()
class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateConfigBase):
_abstract = False
model_weights: bool = FieldUpdate(desc="Save the model weights.")
optimizer_state: bool = FieldUpdate(desc="Save the optimizer state. Default: save if supported by the `format`.")

def _validate(self) -> None:
if self.optimizer_state is None:
if self.optimizer_state is None and hasattr(self.format, "support_optimizer"):
with self._set_implicit_default():
# TODO: Make sure it's a type
self.optimizer_state = self.format.support_optimizer
3 changes: 1 addition & 2 deletions fast_llm/engine/multi_stage/fsdp.py
@@ -4,7 +4,6 @@
import typing

import torch
from torch._C._distributed_c10d import ReduceOp
from torch.distributed import all_reduce, reduce_scatter_tensor

from fast_llm.core.distributed import ProcessGroup
@@ -398,7 +397,7 @@ def reduce_gradients(
out,
self._grad_buffer,
group=self._fsdp_group,
op=ReduceOp.AVG,
op=torch.distributed.ReduceOp.AVG,
)
if accumulate:
triton_add(self._grad_shard, out, self._grad_shard)
3 changes: 1 addition & 2 deletions fast_llm/engine/multi_stage/multi_stage.py
@@ -4,7 +4,6 @@
import warnings

import torch
from torch._C._distributed_c10d import ProcessGroup

from fast_llm.config import Configurable
from fast_llm.engine.base_model.base_model import BaseModel
@@ -611,7 +610,7 @@ class TiedParameter:
# Whether the local rank is involved at all.
on_device: bool
# Process group for reduction.
group: ProcessGroup | None = dataclasses.field(repr=False, init=False)
group: torch.distributed.ProcessGroup | None = dataclasses.field(repr=False, init=False)
all_ranks: set[int]
# The index of the main stage.
main_stage: int