28 changes: 17 additions & 11 deletions fast_llm/core/distributed.py
@@ -29,6 +29,15 @@
logger = logging.getLogger(__name__)


def _get_device(group: ProcessGroup) -> torch.device:
if torch.distributed.is_nccl_available() and isinstance(group, torch.distributed.ProcessGroupNCCL):
return torch.device(torch.cuda.current_device())
elif isinstance(group, torch.distributed.ProcessGroupGloo):
return torch.device("cpu")
else:
raise NotImplementedError(type(group))


@contextlib.contextmanager
def set_timeout(group: ProcessGroup | None, timeout: float | None = None):
if group is not None and timeout is not None:
@@ -72,12 +81,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name:
)


def safe_barrier(
group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None
) -> None:
def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None:
if group:
hashed = hash(value) % 2**32
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device)
out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout)
if out != hashed * group.size():
raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})")

@@ -88,10 +95,9 @@ def allreduce_scalar(
group: torch.distributed.ProcessGroup | None = None,
op=ReduceOp.SUM,
timeout: float | None = None,
device: torch.device | None = None,
) -> float | int:
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device)
value = torch.full([1], value, dtype=dtype, device=_get_device(group))
with set_timeout(group, timeout):
torch.distributed.all_reduce(value, op=op, group=group)
return value.item()
@@ -106,7 +112,7 @@ def all_gather_scalar(
timeout: float | None = None,
):
if group:
value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device())
value = torch.full([1], value, dtype=dtype, device=_get_device(group))
output_tensor = value.new_empty((group.size(),))
with set_timeout(group, timeout):
torch.distributed.all_gather_into_tensor(output_tensor, value, group=group)
@@ -116,15 +122,15 @@ def all_gather_scalar(


def broadcast_scalar(
value: float | int,
value: float | int | None,
dtype: torch.dtype = torch.float64,
group: torch.distributed.ProcessGroup | None = None,
src: int = 0,
timeout: float | None = None,
) -> float | int:
if not group:
return value
tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device()))
tensor = torch.empty([1], dtype=dtype, device=torch.device(_get_device(group)))
if group.rank() == src:
tensor.fill_(value)
broadcast(tensor, src, group, timeout=timeout)
@@ -141,14 +147,14 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None
if group.rank() == src:
tensor = _object_to_tensor(input_object)
size = tensor.numel()
broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group))
broadcast_tensor.copy_(tensor)
broadcast_scalar(size, torch.int64, group, src)
broadcast(broadcast_tensor, src, group)
return input_object
else:
size = int(broadcast_scalar(None, torch.int64, group, src))
output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device())
output_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group))
broadcast(output_tensor, src, group)
return _tensor_to_object(output_tensor)

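The scalar collectives above now derive their staging tensor's device from the process group's backend instead of assuming CUDA. A minimal usage sketch, assuming a Gloo process group initialized elsewhere (the setup call is only indicated in a comment; the imported names mirror the functions in this diff):

import torch
import torch.distributed as dist

from fast_llm.core.distributed import allreduce_scalar, safe_barrier

# Assumes dist.init_process_group("gloo", ...) already ran on every rank;
# dist.group.WORLD then stands in for whatever group fast_llm constructs.
group = dist.group.WORLD

# With a Gloo group, _get_device resolves to CPU, so CPU-only ranks work;
# with a NCCL group it resolves to the current CUDA device instead.
total = allreduce_scalar(1, dtype=torch.int64, group=group)
safe_barrier(group, "after_reduce")
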
2 changes: 1 addition & 1 deletion fast_llm/data/dataset/abstract.py
@@ -61,7 +61,7 @@ def sample(self, config: "SamplingData") -> SampledDataset[SampleType]:

class SamplableIterableDataset[SampleType: Sample](SamplableDataset[SampleType]):
@abc.abstractmethod
def __iter__(self) -> typing.Iterator[SampleType]:
def iterate(self, sampling: "SamplingData") -> typing.Iterator[SampleType]:
pass

def sample(self, config: "SamplingData") -> "SampledIterableDataset[SampleType]":
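With __iter__ replaced by iterate(sampling), iterable datasets receive the sampling context explicitly. A hypothetical minimal subclass illustrating the new contract (the class name and stored samples are made up, and real subclasses also implement the remaining abstract members of SamplableDataset):

import typing

from fast_llm.data.dataset.abstract import SamplableIterableDataset

class InMemoryIterableDataset(SamplableIterableDataset):
    """Toy dataset yielding pre-built samples; for illustration only."""

    def __init__(self, samples, name: str = "in_memory"):
        self._samples = samples
        self._name = name

    def iterate(self, sampling: "SamplingData") -> typing.Iterator:
        # A real implementation would consult `sampling` (e.g. sampling.parameters
        # or sampling.preprocessing) while producing samples.
        yield from self._samples

    @property
    def name(self) -> str:
        return self._name
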
1 change: 0 additions & 1 deletion fast_llm/data/dataset/config.py
@@ -302,7 +302,6 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp


REDIS_DATA_STREAM = "fast_llm_streaming"
REDIS_FIELD = "data"
REDIS_GROUP_NAME = "fast_llm_group"


11 changes: 5 additions & 6 deletions fast_llm/data/dataset/sampled.py
@@ -442,11 +442,10 @@ def __init__(
sampling: SamplingData,
):
self._dataset = dataset
self._config = sampling.config
self._parameters = sampling.parameters
self._sampling = sampling
self._documents: list[SampleType] = []
self._current_length = 0
self._sample_length = self._parameters.sequence_length + self._parameters.extra_tokens
self._sample_length = self._sampling.parameters.sequence_length + self._sampling.parameters.extra_tokens
# Delay iterator creation to avoid pickling issues.
self._iterator: typing.Iterator[SampleType] | None = None

@@ -458,7 +457,7 @@ def requires_broadcast(self) -> bool:

def __getitem__(self, index: int) -> SampleType:
if self._iterator is None:
self._iterator = iter(self._dataset)
self._iterator = self._dataset.iterate(self._sampling)
while self._current_length < self._sample_length:
document = next(self._iterator)
if len(document) > self._sample_length:
@@ -474,7 +473,7 @@ def __getitem__(self, index: int) -> SampleType:
else:
last_length = len(self._documents[-1])
remaining_length = last_length - (self._current_length - self._sample_length)
if self._parameters.truncate_documents:
if self._sampling.parameters.truncate_documents:
documents = self._documents[:-1] + [self._documents[-1].crop(0, remaining_length)]
self._documents = [self._documents[-1].crop(remaining_length, last_length)]
else:
@@ -486,7 +485,7 @@ def __getitem__(self, index: int) -> SampleType:
return sample

def __len__(self) -> int:
return self._parameters.num_samples
return self._sampling.parameters.num_samples

@property
def name(self) -> str:
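The packing logic in __getitem__ accumulates streamed documents until a full sample of sequence_length + extra_tokens tokens is available, then crops the last document and carries the overshoot into the next sample. A standalone sketch of that arithmetic, assuming truncate_documents=True (the lengths are made up):

sample_length = 5  # sequence_length + extra_tokens
document_lengths = [3, 4, 2]  # lengths arriving from dataset.iterate(...)

buffer, current = [], 0
for length in document_lengths:
    buffer.append(length)
    current += length
    if current >= sample_length:
        overshoot = current - sample_length
        kept_from_last = buffer[-1] - overshoot  # cropped to fit this sample
        print(f"sample: {buffer[:-1]} + {kept_from_last} tokens of the last document,"
              f" {overshoot} tokens carried over")
        buffer, current = [overshoot], overshoot
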
134 changes: 101 additions & 33 deletions fast_llm/data/dataset/streaming.py
@@ -1,23 +1,112 @@
import functools
import json
import typing

import redis
import torch.utils.data

from fast_llm.config import Configurable
from fast_llm.config import Config, Configurable, Field, config_class
from fast_llm.data.dataset.abstract import SamplableIterableDataset
from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_FIELD, REDIS_GROUP_NAME, StreamingDatasetConfig
from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_GROUP_NAME, SamplingData, StreamingDatasetConfig
from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig
from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.sample.range import RangeSample
from fast_llm.data.sample.token import TokenSample
from fast_llm.data.sample.token_data import TokenDataSample
from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames
from fast_llm.utils import Assert


def dtype_from_string(name: str) -> torch.dtype:
try:
return getattr(torch, name)
except AttributeError:
raise ValueError(f"Unknown torch dtype: {name}")
@config_class()
class RedisDocument(Config):
"""
Schema for sending and receiving documents through redis, and the associated handling code.
"""

tokens: torch.Tensor = Field()
loss_masking_spans: list[tuple[int, int]] | None = Field(default=None)
chosen_span: tuple[int, int] | None = Field(default=None)
rejected_span: tuple[int, int] | None = Field(default=None)
advantage: float | None = Field(default=None)
old_log_probabilities: torch.Tensor | None = Field(default=None)

def _validate(self):
# Decode message
if isinstance(self.tokens, bytes):
self.tokens = torch.frombuffer(self.tokens, dtype=torch.int64)
elif isinstance(self.tokens, (list, tuple)):
self.tokens = torch.tensor(self.tokens, dtype=torch.int64)
if isinstance(self.loss_masking_spans, str):
self.loss_masking_spans = json.loads(self.loss_masking_spans)
if isinstance(self.chosen_span, str):
self.chosen_span = json.loads(self.chosen_span)
if isinstance(self.rejected_span, str):
self.rejected_span = json.loads(self.rejected_span)
if isinstance(self.old_log_probabilities, bytes):
self.old_log_probabilities = torch.frombuffer(self.old_log_probabilities, dtype=torch.float32)
elif isinstance(self.old_log_probabilities, (list, tuple)):
self.old_log_probabilities = torch.tensor(self.old_log_probabilities, dtype=torch.float32)
super()._validate()
if self.old_log_probabilities is not None:
Assert.eq(len(self.old_log_probabilities), self.num_tokens)

@functools.cached_property
def num_tokens(self) -> int:
return len(self.tokens)

@classmethod
def from_message(cls, message: dict[bytes, bytes]) -> typing.Self:
# Read
kwargs = {}
for key, value in message.items():
key = key.decode()
if key == "data":
kwargs.update(json.loads(value))
else:
kwargs[key] = value
return cls.from_dict(kwargs)

def to_message(self) -> dict[str, str | int | float | bytes]:
# Encode message
message: dict[str, str | int | float | bytes] = {"tokens": self.tokens.numpy().tobytes()}
if self.old_log_probabilities is not None:
message["old_log_probabilities"] = self.old_log_probabilities.numpy().tobytes()
data = {}
if self.loss_masking_spans is not None:
data["loss_masking_spans"] = self.loss_masking_spans
if self.chosen_span is not None:
data["chosen_span"] = self.chosen_span
if self.rejected_span is not None:
data["rejected_span"] = self.rejected_span
if self.advantage is not None:
data["advantage"] = self.advantage
if data:
message["data"] = json.dumps(data)
return message

def to_sample(self, preprocessing: LanguageModelPreprocessingConfig | None):
sample_size = len(self.tokens)
# TODO: Check explicitly that required data is available?
return LanguageModelSample(
tokens=TokenSample(self.tokens, [sample_size]),
loss_masking_spans=(
RangeSample([(begin, end) for begin, end in self.loss_masking_spans], sample_size)
if preprocessing.use_loss_masking_spans
else None
),
chosen_spans=RangeSample([self.chosen_span], sample_size) if preprocessing.use_preference_spans else None,
rejected_spans=(
RangeSample([self.rejected_span], sample_size) if preprocessing.use_preference_spans else None
),
advantages=(
TokenDataSample(torch.full([sample_size], self.advantage, dtype=torch.float32))
if preprocessing.use_advantages
else None
),
old_log_probabilities=(
TokenDataSample(self.old_log_probabilities) if preprocessing.use_old_log_probabilities else None
),
)
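
A producer-side sketch of the wire format defined by RedisDocument, assuming a Redis server on localhost and assuming from_dict accepts in-memory values the same way from_message does (the token values and spans are made up):

import redis
import torch

from fast_llm.data.dataset.config import REDIS_DATA_STREAM
from fast_llm.data.dataset.streaming import RedisDocument

document = RedisDocument.from_dict(
    {
        "tokens": torch.tensor([5, 17, 42, 8], dtype=torch.int64),
        "loss_masking_spans": [(1, 3)],
        "advantage": 0.25,
    }
)
client = redis.Redis(host="localhost", port=6379)
# Tokens (and old_log_probabilities, if set) travel as raw bytes; the other
# optional fields are packed into the JSON "data" field by to_message().
client.xadd(REDIS_DATA_STREAM, document.to_message())
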


class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: LanguageModelSample](
@@ -40,7 +129,7 @@ def requires_broadcast(self) -> bool:
def name(self) -> str:
return self._name

def __iter__(self) -> typing.Iterator[LanguageModelSample]:
def iterate(self, sampling: SamplingData) -> typing.Iterator[SampleType]:
worker_info = torch.utils.data.get_worker_info()
if worker_info is not None and worker_info.num_workers > 1:
raise RuntimeError("StreamingDataset can work only with one instance per rank")
@@ -77,29 +166,8 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]:
noack=True,
)
if messages:
for stream_key, msgs in messages:
for stream_key, messages_ in messages:
assert stream_key == REDIS_DATA_STREAM.encode()
for msg_id, msg_data in msgs:
yield self._read_document(json.loads(msg_data[REDIS_FIELD.encode()]))

def _read_document(self, data: dict) -> LanguageModelSample:
tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"]))
sample_size = len(tokens)
if "loss_masking_spans" in data:
loss_masking_spans = RangeSample([(begin, end) for begin, end in data["loss_masking_spans"]], sample_size)
else:
loss_masking_spans = None
if "chosen_spans" in data:
chosen_spans = RangeSample([(begin, end) for begin, end in data["chosen_spans"]], sample_size)
else:
chosen_spans = None
if "rejected_spans" in data:
rejected_spans = RangeSample([(begin, end) for begin, end in data["rejected_spans"]], sample_size)
else:
rejected_spans = None
return LanguageModelSample(
TokenSample(tokens, [sample_size]),
loss_masking_spans,
chosen_spans,
rejected_spans,
)
for message_id, message in messages_:
yield RedisDocument.from_message(message).to_sample(sampling.preprocessing)
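
On the consumer side, each raw stream entry is decoded back through RedisDocument before becoming a LanguageModelSample, mirroring the loop above. A small decode sketch with a fabricated message, shaped as redis-py would return it (bytes keys and values):

import json
import torch

from fast_llm.data.dataset.streaming import RedisDocument

raw_message = {
    b"tokens": torch.tensor([5, 17, 42, 8], dtype=torch.int64).numpy().tobytes(),
    b"data": json.dumps({"loss_masking_spans": [[1, 3]], "advantage": 0.25}).encode(),
}
document = RedisDocument.from_message(raw_message)
print(document.num_tokens, document.advantage)
# In the real loop this then becomes a sample via
# document.to_sample(sampling.preprocessing).
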
15 changes: 15 additions & 0 deletions fast_llm/data/preparator/gpt_memmap/config.py
@@ -48,6 +48,10 @@ def has_preference_spans(self) -> bool:
def has_images(self) -> bool:
return False

@functools.cached_property
def has_grpo_data(self) -> bool:
return False


@config_class(dynamic_type={LanguageModelSourceConfig: "document"})
class DocumentSourceConfig(LanguageModelSourceConfig):
@@ -91,6 +95,13 @@ class DocumentSourceConfig(LanguageModelSourceConfig):
desc="Field containing image positions in the text.",
hint=FieldHint.optional,
)
# TODO: Old log probabilities are made up (zeros) since we don't know the token count in advance.
advantages: str | None = Field(
default=None,
desc="Field containing advantaged for policy optimization."
" Mainly for debugging purposed as advantages are typically generated at runtime.",
hint=FieldHint.optional,
)

@functools.cached_property
def columns(self) -> list[str]:
Expand All @@ -117,6 +128,10 @@ def has_images(self) -> bool:
Assert.eq(self.images is None, self.image_positions is None)
return self.images is not None

@functools.cached_property
def has_grpo_data(self) -> bool:
return self.advantages is not None

def _validate(self):
super()._validate()
if self.has_preference_spans and self.has_loss_masking_span:
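
For completeness, a hypothetical "document" source configuration enabling the new advantages column; only the advantages key comes from this diff, and the placeholder comment stands in for the usual column mappings, which may differ in the real DocumentSourceConfig schema:

source_config = {
    "type": "document",
    # ... other column mappings (text, spans, images, ...) as usual ...
    "advantages": "advantages",  # new: column with per-document advantages, mainly for debugging
}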