From 036eaa42b95a741d1a4fc06f3b14199a245dde39 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 22 Jan 2026 23:28:16 -0500 Subject: [PATCH 1/4] stuff --- fast_llm/data/dataset/config.py | 1 - fast_llm/data/dataset/streaming.py | 128 +++++++++--- fast_llm/data/preprocessing/language_model.py | 2 + fast_llm/data/sample/language_model.py | 93 ++++++++- fast_llm/data/sample/patch.py | 2 +- fast_llm/data/sample/range.py | 2 +- fast_llm/data/sample/token.py | 2 +- fast_llm/data/sample/token_data.py | 190 ++++++++++++++++++ tests/data/test_streaming.py | 68 +++++-- tests/utils/redis.py | 8 +- 10 files changed, 436 insertions(+), 60 deletions(-) create mode 100644 fast_llm/data/sample/token_data.py diff --git a/fast_llm/data/dataset/config.py b/fast_llm/data/dataset/config.py index ede450dfa..8b18b59ba 100644 --- a/fast_llm/data/dataset/config.py +++ b/fast_llm/data/dataset/config.py @@ -302,7 +302,6 @@ def build(self, preprocessing: PreprocessingConfig) -> "IndexedDataset[SampleTyp REDIS_DATA_STREAM = "fast_llm_streaming" -REDIS_FIELD = "data" REDIS_GROUP_NAME = "fast_llm_group" diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index c261e383e..48275988f 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -1,23 +1,108 @@ +import functools import json import typing import redis import torch.utils.data -from fast_llm.config import Configurable +from fast_llm.config import Config, Configurable, Field, config_class from fast_llm.data.dataset.abstract import SamplableIterableDataset -from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_FIELD, REDIS_GROUP_NAME, StreamingDatasetConfig +from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_GROUP_NAME, StreamingDatasetConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample +from fast_llm.data.sample.token_data import TokenDataSample from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.utils import Assert -def dtype_from_string(name: str) -> torch.dtype: - try: - return getattr(torch, name) - except AttributeError: - raise ValueError(f"Unknown torch dtype: {name}") +@config_class() +class RedisDocument(Config): + """ + Schema for sending and receiving documents through redis, and the associated handling code. 
+ """ + + tokens: torch.Tensor = Field() + loss_masking_spans: list[tuple[int, int]] | None = Field(default=None) + chosen_span: tuple[int, int] | None = Field(default=None) + rejected_span: tuple[int, int] | None = Field(default=None) + advantage: float | None = Field(default=None) + old_log_probabilities: torch.Tensor | None = Field(default=None) + + def _validate(self): + # Decode message + if isinstance(self.tokens, bytes): + self.tokens = torch.frombuffer(self.tokens, dtype=torch.int64) + elif isinstance(self.tokens, (list, tuple)): + self.tokens = torch.tensor(self.tokens, dtype=torch.int64) + if isinstance(self.loss_masking_spans, str): + self.loss_masking_spans = json.loads(self.loss_masking_spans) + if isinstance(self.chosen_span, str): + self.chosen_span = json.loads(self.chosen_span) + if isinstance(self.rejected_span, str): + self.rejected_span = json.loads(self.rejected_span) + if isinstance(self.old_log_probabilities, bytes): + self.old_log_probabilities = torch.frombuffer(self.old_log_probabilities, dtype=torch.float32) + elif isinstance(self.old_log_probabilities, (list, tuple)): + self.old_log_probabilities = torch.tensor(self.old_log_probabilities, dtype=torch.float32) + super()._validate() + if self.old_log_probabilities is not None: + Assert.eq(len(self.old_log_probabilities), self.num_tokens) + + @functools.cached_property + def num_tokens(self) -> int: + return len(self.tokens) + + @classmethod + def from_message(cls, message: dict[bytes, bytes]) -> typing.Self: + # Read + kwargs = {} + for key, value in message.items(): + key = key.decode() + if key == "data": + kwargs.update(json.loads(value)) + else: + kwargs[key] = value + return cls.from_dict(kwargs) + + def to_message(self) -> dict[str, str | int | float | bytes]: + # Encode message + message: dict[str, str | int | float | bytes] = {"tokens": self.tokens.numpy().tobytes()} + if self.old_log_probabilities is not None: + message["old_log_probabilities"] = self.old_log_probabilities.numpy().tobytes() + data = {} + if self.loss_masking_spans is not None: + data["loss_masking_spans"] = self.loss_masking_spans + if self.chosen_span is not None: + data["chosen_span"] = self.chosen_span + if self.rejected_span is not None: + data["rejected_span"] = self.rejected_span + if self.advantage is not None: + data["advantage"] = self.advantage + if data: + message["data"] = json.dumps(data) + return message + + def to_sample(self): + sample_size = len(self.tokens) + return LanguageModelSample( + tokens=TokenSample(self.tokens, [sample_size]), + loss_masking_spans=( + None + if self.loss_masking_spans is None + else RangeSample([(begin, end) for begin, end in self.loss_masking_spans], sample_size) + ), + chosen_spans=None if self.chosen_span is None else RangeSample([self.chosen_span], sample_size), + rejected_spans=None if self.rejected_span is None else RangeSample([self.rejected_span], sample_size), + advantages=( + None + if self.advantage is None + else TokenDataSample(torch.full([sample_size], self.advantage, dtype=torch.float32)) + ), + old_log_probabilities=( + None if self.old_log_probabilities is None else TokenDataSample(self.old_log_probabilities) + ), + ) class RedisStreamingDataset[ConfigType: StreamingDatasetConfig, SampleType: LanguageModelSample]( @@ -77,29 +162,8 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: noack=True, ) if messages: - for stream_key, msgs in messages: + for stream_key, messages_ in messages: assert stream_key == REDIS_DATA_STREAM.encode() - for msg_id, msg_data in msgs: - 
yield self._read_document(json.loads(msg_data[REDIS_FIELD.encode()])) - - def _read_document(self, data: dict) -> LanguageModelSample: - tokens = torch.tensor(data["tokens"], dtype=dtype_from_string(data["tokens_dtype"])) - sample_size = len(tokens) - if "loss_masking_spans" in data: - loss_masking_spans = RangeSample([(begin, end) for begin, end in data["loss_masking_spans"]], sample_size) - else: - loss_masking_spans = None - if "chosen_spans" in data: - chosen_spans = RangeSample([(begin, end) for begin, end in data["chosen_spans"]], sample_size) - else: - chosen_spans = None - if "rejected_spans" in data: - rejected_spans = RangeSample([(begin, end) for begin, end in data["rejected_spans"]], sample_size) - else: - rejected_spans = None - return LanguageModelSample( - TokenSample(tokens, [sample_size]), - loss_masking_spans, - chosen_spans, - rejected_spans, - ) + for message_id, message in messages_: + print(message) + yield RedisDocument.from_message(message).to_sample() diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py index d54776eec..87d176663 100644 --- a/fast_llm/data/preprocessing/language_model.py +++ b/fast_llm/data/preprocessing/language_model.py @@ -22,6 +22,8 @@ class LanguageModelPreprocessingConfig(PreprocessingConfig): vocab_size: int | None = Field(default=None) use_loss_masking_spans: bool = Field(default=False) use_preference_spans: bool = Field(default=False) + use_advantages: bool = Field(default=False) + use_old_log_probabilities: bool = Field(default=False) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index 22b89acf1..e3dab9bc2 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -39,6 +39,7 @@ RangeWriter, ) from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter +from fast_llm.data.sample.token_data import TokenDataBatch, TokenDataReader, TokenDataReaderConfig, TokenDataSample from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -53,12 +54,16 @@ def __init__( chosen_spans: RangeSample | None = None, rejected_spans: RangeSample | None = None, image_patches: PatchSample | None = None, + advantages: TokenDataSample | None = None, + old_log_probabilities: TokenDataSample | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans self.chosen_spans = chosen_spans self.rejected_spans = rejected_spans self.image_patches = image_patches + self.advantages = advantages + self.old_log_probabilities = old_log_probabilities @classmethod def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: @@ -68,6 +73,10 @@ def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: _merge_optional(RangeSample.from_documents, [document.chosen_spans for document in documents]), _merge_optional(RangeSample.from_documents, [document.rejected_spans for document in documents]), _merge_optional(PatchSample.from_documents, [document.image_patches for document in documents]), + _merge_optional(TokenDataSample.from_documents, [document.advantages for document in documents]), + _merge_optional( + TokenDataSample.from_documents, [document.old_log_probabilities for document in documents] + ), ) def crop(self, begin: int, end: int) -> typing.Self: @@ -77,18 +86,22 @@ def crop(self, begin: int, end: int) -> typing.Self: _crop_optional(self.chosen_spans, 
begin, end), _crop_optional(self.rejected_spans, begin, end), _crop_optional(self.image_patches, begin, end), + _crop_optional(self.advantages, begin, end), + _crop_optional(self.old_log_probabilities, begin, end), ) def __len__(self) -> int: return len(self.tokens) def get_padding(self, size: int) -> typing.Self: - return LanguageModelSample( + return self.__class__( self.tokens.get_padding(size), None if self.loss_masking_spans is None else self.loss_masking_spans.get_padding(size), None if self.chosen_spans is None else self.chosen_spans.get_padding(size), None if self.rejected_spans is None else self.rejected_spans.get_padding(size), None if self.image_patches is None else self.image_patches.get_padding(size), + None if self.advantages is None else self.advantages.get_padding(size), + None if self.old_log_probabilities is None else self.old_log_probabilities.get_padding(size), ) @@ -100,12 +113,16 @@ def __init__( chosen_spans: RangeBatch | None = None, rejected_spans: RangeBatch | None = None, image_patches: PatchBatch | None = None, + advantages: TokenDataBatch | None = None, + old_log_probabilities: TokenDataBatch | None = None, ): self.tokens = tokens self.loss_masking_spans = loss_masking_spans self.chosen_spans = chosen_spans self.rejected_spans = rejected_spans self.image_patches = image_patches + self.advantages = advantages + self.old_log_probabilities = old_log_probabilities @classmethod def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.Self: @@ -115,6 +132,8 @@ def from_samples(cls, samples: typing.Iterable[LanguageModelSample]) -> typing.S _merge_optional(RangeBatch.from_samples, [sample.chosen_spans for sample in samples]), _merge_optional(RangeBatch.from_samples, [sample.rejected_spans for sample in samples]), _merge_optional(PatchBatch.from_samples, [sample.image_patches for sample in samples]), + _merge_optional(TokenDataBatch.from_samples, [sample.advantages for sample in samples]), + _merge_optional(TokenDataBatch.from_samples, [sample.old_log_probabilities for sample in samples]), ) def crop(self, begin: int, end: int) -> typing.Self: @@ -124,6 +143,8 @@ def crop(self, begin: int, end: int) -> typing.Self: _crop_optional(self.chosen_spans, begin, end), _crop_optional(self.rejected_spans, begin, end), _crop_optional(self.image_patches, begin, end), + _crop_optional(self.advantages, begin, end), + _crop_optional(self.old_log_probabilities, begin, end), ) def to_device_(self, device: "torch.device | str"): @@ -136,6 +157,10 @@ def to_device_(self, device: "torch.device | str"): self.rejected_spans.to_device_(device) if self.image_patches is not None: self.image_patches.to_device_(device) + if self.advantages is not None: + self.advantages.to_device_(device) + if self.old_log_probabilities is not None: + self.old_log_probabilities.to_device_(device) def _merge_optional[T](fn: typing.Callable[[typing.Iterable], T], args: typing.Iterable) -> T | None: @@ -157,6 +182,8 @@ class LanguageModelReaderConfig(MemmapIndexDatasetReaderConfig): chosen_spans: MemmapReaderBaseConfig = Field() rejected_spans: MemmapReaderBaseConfig = Field() image_patches: MemmapReaderBaseConfig = Field() + advantages: MemmapReaderBaseConfig = Field() + old_log_probabilities: MemmapReaderBaseConfig = Field() def _validate(self) -> None: super()._validate() @@ -192,6 +219,16 @@ def _validate(self) -> None: self.rejected_spans, RangeReaderConfig if self.preprocessing.use_preference_spans else NullReaderConfig, ) + Assert.custom( + isinstance, + self.advantages, + 
TokenDataReaderConfig if self.preprocessing.use_advantages else NullReaderConfig, + ) + Assert.custom( + isinstance, + self.old_log_probabilities, + TokenDataReaderConfig if self.preprocessing.use_old_log_probabilities else NullReaderConfig, + ) if self.preprocessing.use_image_patches: Assert.custom(isinstance, self.image_patches, PatchReaderConfig) Assert.eq(self.image_patches.patch_shape, self.preprocessing.image_patches.patch_shape) @@ -222,6 +259,8 @@ def _expected_buffer_size(self) -> int: + self.chosen_spans.expected_buffer_size + self.rejected_spans.expected_buffer_size + self.image_patches.expected_buffer_size + + self.advantages.expected_buffer_size + + self.old_log_probabilities.expected_buffer_size ) def get_metadata(self) -> dict[str, typing.Any]: @@ -235,6 +274,10 @@ def get_metadata(self) -> dict[str, typing.Any]: out["rejected_spans"] = self.rejected_spans.get_metadata() if not isinstance(self.image_patches, NullReaderConfig): out["image_patches"] = self.image_patches.get_metadata() + if not isinstance(self.advantages, NullReaderConfig): + out["advantages"] = self.advantages.get_metadata() + if not isinstance(self.old_log_probabilities, NullReaderConfig): + out["old_log_probabilities"] = self.old_log_probabilities.get_metadata() return out @classmethod @@ -257,6 +300,12 @@ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typi out["image_patches"] = PatchReaderConfig.blend_metadata( [metadata_["image_patches"] for metadata_ in metadata] ) + if "advantages" in metadata[0]: + out["advantages"] = RangeReaderConfig.blend_metadata([metadata_["advantages"] for metadata_ in metadata]) + if "old_log_probabilities" in metadata[0]: + out["old_log_probabilities"] = RangeReaderConfig.blend_metadata( + [metadata_["old_log_probabilities"] for metadata_ in metadata] + ) return out @@ -290,6 +339,10 @@ def __init__( self._chosen_spans = self._config.chosen_spans.get_reader(buffer) self._rejected_spans = self._config.rejected_spans.get_reader(buffer) + if self._model_preprocessing.use_advantages: + self._advantages = self._config.advantages.get_reader(buffer) + self._old_log_probabilities = self._config.old_log_probabilities.get_reader(buffer) + if self._model_preprocessing.use_image_patches: model_image_preprocessing: ImagePatchConfig = self._model_preprocessing.image_patches if isinstance(self._config.image_patches, NullReaderConfig): @@ -334,6 +387,12 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: else None ), image_patches, + (self._advantages.get_document(index, begin, end) if self._model_preprocessing.use_advantages else None), + ( + self._old_log_probabilities.get_document(index, begin, end) + if self._model_preprocessing.use_old_log_probabilities + else None + ), ) def get_document_sizes(self) -> torch.Tensor: @@ -356,6 +415,10 @@ def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dic metadata["rejected_spans"] = self._rejected_spans.get_split(begin_index, end_index) if hasattr(self, "_image_patches") and isinstance(self._image_patches, PatchReader): metadata["image_patches"] = self._image_patches.get_split(begin_index, end_index) + if hasattr(self, "_advantages") and isinstance(self._advantages, TokenDataReader): + metadata["advantages"] = self._advantages.get_split(begin_index, end_index) + if hasattr(self, "_old_log_probabilities") and isinstance(self._old_log_probabilities, TokenDataReader): + metadata["old_log_probabilities"] = self._old_log_probabilities.get_split(begin_index, end_index) return 
begin_index, end_index, metadata @@ -379,6 +442,10 @@ def __enter__(self): self._rejected_spans_writer = RangeWriter(self._path.joinpath("rejected_spans")).__enter__() if self._preprocessing_config.use_image_patches: self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() + if self._preprocessing_config.use_advantages: + self._advantages_writer = PatchWriter(self._path.joinpath("advantages")).__enter__() + if self._preprocessing_config.use_old_log_probabilities: + self._old_log_probabilities_writer = PatchWriter(self._path.joinpath("old_log_probabilities")).__enter__() return self def write(self, document: LanguageModelSample): @@ -403,6 +470,14 @@ def write(self, document: LanguageModelSample): assert document.image_patches is not None self._image_patches_writer.write(document.image_patches) + if self._preprocessing_config.use_advantages: + assert document.advantages is not None + self._advantages_writer.write(document.advantages) + + if self._preprocessing_config.use_old_log_probabilities: + assert document.old_log_probabilities is not None + self._old_log_probabilities_writer.write(document.old_log_probabilities) + def __exit__(self, exc_type, exc_val, exc_tb): self._token_writer.__exit__(exc_type, exc_val, exc_tb) if self._preprocessing_config.use_loss_masking_spans: @@ -412,6 +487,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._rejected_spans_writer.__exit__(exc_type, exc_val, exc_tb) if self._preprocessing_config.use_image_patches: self._image_patches_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_advantages: + self._advantages_writer.__exit__(exc_type, exc_val, exc_tb) + if self._preprocessing_config.use_old_log_probabilities: + self._old_log_probabilities_writer.__exit__(exc_type, exc_val, exc_tb) if exc_type is None: # A dummy config so we can verify the begin and end offsets. 
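# Illustrative sketch (assumed usage; sizes and values are made up): with
# `use_advantages` / `use_old_log_probabilities` enabled in the preprocessing config,
# a prepared document now carries per-token advantages and old log probabilities next
# to its tokens, and the writer above stores each optional field through its own
# sub-writer before chaining the buffer offsets in `_get_config` below.
import torch

from fast_llm.data.sample.language_model import LanguageModelSample
from fast_llm.data.sample.token import TokenSample
from fast_llm.data.sample.token_data import TokenDataSample

tokens = torch.arange(6, dtype=torch.int64)
document = LanguageModelSample(
    tokens=TokenSample(tokens, [len(tokens)]),
    advantages=TokenDataSample(torch.full([len(tokens)], 0.5, dtype=torch.float32)),
    old_log_probabilities=TokenDataSample(torch.zeros(len(tokens), dtype=torch.float32)),
)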
@@ -475,6 +554,16 @@ def _get_config(self, begin: int, end: int | None): offset = image_patches.end else: image_patches = NullReaderConfig() + if self._preprocessing_config.use_advantages: + advantages = self._advantages_writer.get_config(offset) + offset = advantages.end + else: + advantages = NullReaderConfig() + if self._preprocessing_config.use_old_log_probabilities: + old_log_probabilities = self._old_log_probabilities_writer.get_config(offset) + offset = old_log_probabilities.end + else: + old_log_probabilities = NullReaderConfig() if end is None: end = offset + len(LanguageModelReaderConfig.footer) @@ -488,6 +577,8 @@ def _get_config(self, begin: int, end: int | None): rejected_spans=rejected_spans, image_patches=image_patches, preprocessing=self._preprocessing_config, + advantages=advantages, + old_log_probabilities=old_log_probabilities, ) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index 7ae537104..32ea60cb8 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -85,7 +85,7 @@ def __len__(self) -> int: return self.sample_size def get_padding(self, size: int) -> typing.Self: - return PatchSample( + return self.__class__( self.patches.new_empty((0, *self.patches.shape[1:])), self.token_map.new_empty(0), self.positions.new_empty([0, self.patches.ndim - 2]), diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 53683342a..f57ee04d9 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -52,7 +52,7 @@ def __len__(self) -> int: return self.sample_size def get_padding(self, size: int) -> typing.Self: - return RangeSample([], size) + return self.__class__([], size) class RangeBatch(Batch): diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index cd4d7fa02..6ab55dbba 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -58,7 +58,7 @@ def __len__(self) -> int: return len(self.tokens) def get_padding(self, size: int) -> typing.Self: - return TokenSample(torch.full([size], -100, dtype=self.tokens.dtype), [size]) + return self.__class__(torch.full([size], -100, dtype=self.tokens.dtype), [size]) class TokenBatch(Batch): diff --git a/fast_llm/data/sample/token_data.py b/fast_llm/data/sample/token_data.py new file mode 100644 index 000000000..6d2a6f9d1 --- /dev/null +++ b/fast_llm/data/sample/token_data.py @@ -0,0 +1,190 @@ +import functools +import math +import typing + +import numpy as np +import torch + +from fast_llm.config import Field, config_class +from fast_llm.data.preprocessing.abstract import PreprocessingConfig +from fast_llm.data.sample.abstract import ( + Batch, + MemmapIndexedDatasetReader, + MemmapReaderBase, + MemmapReaderBaseConfig, + MemmapReaderConfig, + MemmapWriter, + Sample, +) +from fast_llm.data.sample.patch import PatchReaderBaseConfig +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.utils import Assert, get_unique + + +class TokenDataSample(Sample): + """ + A reusable component holding tensor-valued data of fixed dtype and shape for each token. + TODO: Use as base class for `TokenSample` and `PatchSample`? 
+ """ + + def __init__(self, data: torch.Tensor): + self.data = data + + @classmethod + def from_documents(cls, documents: typing.Iterable[typing.Self]) -> typing.Self: + return cls(torch.cat([document.data for document in documents])) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__(self.data[begin:end]) + + def __len__(self) -> int: + return len(self.data) + + def get_padding(self, size: int) -> typing.Self: + return self.__class__(torch.full([size], 0, dtype=self.data.dtype)) + + +class TokenDataBatch(Batch): + def __init__(self, data: torch.Tensor) -> None: + self.data = data + + @classmethod + def from_samples(cls, samples: typing.Iterable[TokenDataSample]) -> typing.Self: + return cls(torch.stack([sample.data for sample in samples])) + + def crop(self, begin: int, end: int) -> typing.Self: + return self.__class__(self.data[:, begin:end]) + + def to_device_(self, device: "torch.device | str"): + self.data = self.data.to(device, non_blocking=True) + + +@config_class(dynamic_type={MemmapReaderBaseConfig: "token_data"}) +class TokenDataReaderConfig(MemmapReaderConfig): + _abstract = False + header: typing.ClassVar[bytes] = b"token data begin" + footer: typing.ClassVar[bytes] = b"token data end" + num_documents: int = Field() + num_tokens: int = Field() + shape: tuple[int, ...] = Field() + data_type: DataType = Field() + + def __len__(self) -> int: + return self.num_documents + + @functools.cached_property + def size(self) -> int: + return math.prod(self.shape) + + @property + def reader_class(self) -> "type[TokenDataReader]": + return TokenDataReader + + @property + def writer_class(self) -> "type[TokenDataWriter]": + return TokenDataWriter + + @property + def _expected_buffer_size(self) -> int: + return ( + self.num_tokens * self.data_type.torch.itemsize * self.size + + (self.num_documents + 1) * torch.int64.itemsize + ) + + def get_metadata(self) -> dict[str, typing.Any]: + return { + "num_tokens": self.num_tokens, + "num_documents": self.num_documents, + "data_type": str(self.data_type), + "shape": self.shape, + } + + @classmethod + def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typing.Any]: + return { + "num_tokens": sum(metadata_["num_tokens"] for metadata_ in metadata), + "num_documents": sum(metadata_["num_documents"] for metadata_ in metadata), + "data_type": get_unique(metadata_["data_type"] for metadata_ in metadata), + "shape": get_unique(metadata_["shape"] for metadata_ in metadata), + } + + +class TokenDataReader[ConfigType: TokenDataReaderConfig](MemmapIndexedDatasetReader[ConfigType]): + def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): + super().__init__(config, buffer, model_preprocessing) + self._data = torch.frombuffer( + self._buffer, + dtype=self._config.data_type.torch, + count=self._config.num_tokens * self._config.size, + ).view(-1, *self._config.shape) + self._size_cumsums = torch.frombuffer( + self._buffer, dtype=torch.int64, count=self._config.num_documents + 1, offset=self._data.nbytes + ) + + def get_document(self, index: int, begin: int, end: int) -> Sample: + begin_ = self._size_cumsums[index].item() + return TokenDataSample(self._data[begin_ + begin : begin_ + end]) + + def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: + Assert.custom(lambda x: x == sorted(x), [0, begin_ratio, end_ratio, 1]) + begin_index = _get_nearest_split(self._size_cumsums[1:], begin_ratio * self.num_tokens) + 
end_index = _get_nearest_split(self._size_cumsums[1:], end_ratio * self.num_tokens) + + return ( + begin_index, + end_index, + { + "num_tokens": self._size_cumsums[end_index].item() - self._size_cumsums[begin_index].item(), + "num_documents": end_index - begin_index, + "data_type": str(self._config.data_type), + }, + ) + + +class EmptyPatchReader[ConfigType: PatchReaderBaseConfig](MemmapReaderBase[ConfigType]): + def get_document(self, index: int, begin: int, end: int) -> Sample: + # TODO: Does this make sense? + return TokenDataSample(torch.zeros(end - begin, *self._config.shape, dtype=self._config.data_type.torch)) + + +def _get_nearest_split(cumsum: torch.Tensor, value: float) -> int: + left = torch.searchsorted(cumsum, value, side="right") + if left == len(cumsum): + return left.item() + return left.item() + 1 if (value - cumsum[left]) / (cumsum[left + 1] - cumsum[left]) > 0.5 else left.item() + + +class TokenDataWriter(MemmapWriter): + def __enter__(self): + super().__enter__() + self._size_cumsum = [0] + self._data_type = None + return self + + def write(self, document: TokenDataSample): + super().write(document) + if self._data_type is None: + self._data_type = document.data.dtype + else: + Assert.eq(self._data_type, document.data.dtype) + self._stream.write(document.data.numpy().tobytes()) + self._size_cumsum.append(self._size_cumsum[-1] + len(document.data)) + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None: + self._stream.write(np.array(self._size_cumsum, dtype=np.int64).tobytes(order="C")) + super().__exit__(exc_type, exc_val, exc_tb) + + @classmethod + def _get_config_class(cls) -> type[TokenDataReaderConfig]: + return TokenDataReaderConfig + + def _get_config(self, begin: int, end: int): + return TokenDataReaderConfig( + begin=begin, + end=end, + num_documents=len(self._size_cumsum) - 1, + num_tokens=self._size_cumsum[-1], + data_type=DataType.from_torch(self._data_type), + preprocessing=self._preprocessing_config, + ) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index d8953488f..8ff00db6b 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -12,8 +12,8 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.gpt.config import GPTDataConfig from fast_llm.data.data.gpt.data import GPTData -from fast_llm.data.dataset.config import RedisConfig, SamplingParameters, StreamingDatasetConfig -from fast_llm.data.dataset.streaming import RedisStreamingDataset +from fast_llm.data.dataset.config import REDIS_DATA_STREAM, RedisConfig, SamplingParameters, StreamingDatasetConfig +from fast_llm.data.dataset.streaming import RedisDocument, RedisStreamingDataset from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames @@ -21,7 +21,7 @@ from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert from tests.conftest import WorkerResources -from tests.utils.redis import make_sampling, push_msg, redis_batch_producer +from tests.utils.redis import make_sampling, push_message, redis_batch_producer from tests.utils.subtest import DistributedTestContext from tests.utils.utils import requires_cuda @@ -40,31 +40,63 @@ def fake_redis(monkeypatch): @pytest.mark.parametrize( - "messages", + "documents", [ - (range(3),), - (range(3), range(3, 7)), - (range(3), range(5), [], [9, 4]), + ([0, 1, 2],), + 
([0, 1, 2], [3, 4, 5, 6]), + ([0, 1, 2], [0, 1, 2, 3, 4], [9, 4]), + ( + {"tokens": [0, 1, 2], "advantage": 0.33, "old_log_probabilities": [0.25, -0.52, 0.99]}, + {"tokens": [5, 3, 2, 0], "loss_masking_spans": [(0, 1), (2, 3)]}, + {"tokens": [5, 3, 2, 0], "chosen_span": (0, 2), "rejected_span": (2, 3)}, + ), ], ) def test_streaming_dataset( fake_redis: fakeredis.FakeRedis, - messages: tuple[list[int], ...], + documents: tuple[list[int] | dict[str, typing.Any], ...], worker_resources: WorkerResources, ): """StreamingDataset should read a message and convert it into LanguageModelSample.""" stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) dataset_iterator = iter(RedisStreamingDataset(stream_config, DistributedConfig())) - for message in messages: - push_msg(fake_redis, list(message)) - for message in messages: + documents = [document if isinstance(document, dict) else {"tokens": document} for document in documents] + for document in documents: + fake_redis.xadd(REDIS_DATA_STREAM, RedisDocument.from_dict(document).to_message()) + for document in documents: sample = next(dataset_iterator) assert isinstance(sample, LanguageModelSample) - Assert.eq(sample.tokens.tokens.tolist(), list(message)) - Assert.eq(sample.tokens.lengths, [len(message)]) - assert sample.loss_masking_spans is None - assert sample.chosen_spans is None - assert sample.rejected_spans is None + Assert.eq(sample.tokens.tokens.tolist(), document["tokens"]) + Assert.eq(sample.tokens.lengths, [len(document["tokens"])]) + + if "loss_masking_spans" in document: + Assert.eq(sample.loss_masking_spans.ranges, document["loss_masking_spans"]) + else: + assert sample.loss_masking_spans is None + + if "chosen_span" in document: + Assert.eq(sample.chosen_spans.ranges, [document["chosen_span"]]) + else: + assert sample.chosen_spans is None + + if "rejected_span" in document: + Assert.eq(sample.rejected_spans.ranges, [document["rejected_span"]]) + else: + assert sample.rejected_spans is None + + assert sample.image_patches is None + + if "advantage" in document: + Assert.rms_close( + sample.advantages.data, torch.full([len(document["tokens"])], document["advantage"]), 1e-8 + ) + else: + assert sample.advantages is None + + if "old_log_probabilities" in document: + Assert.rms_close(sample.old_log_probabilities.data, torch.tensor(document["old_log_probabilities"]), 1e-8) + else: + assert sample.old_log_probabilities is None @pytest.mark.parametrize( @@ -95,12 +127,12 @@ def test_streaming_sampled_dataset( ): """StreamingDataset should read a message and convert it into LanguageModelSample.""" stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) - distributed = Distributed(DistributedConfig(), use_cpu=True) + distributed = Distributed(DistributedConfig(use_cuda=False)) dataset_iterator = iter( RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(5, 1, distributed)) ) for message in messages: - push_msg(fake_redis, list(message)) + push_message(fake_redis, {"tokens": list(message)}) for expected_sample, expected_lengths_ in zip(expected_samples, expected_lengths, strict=True): sample = next(dataset_iterator) assert isinstance(sample, LanguageModelSample) diff --git a/tests/utils/redis.py b/tests/utils/redis.py index 591ee74e6..ec8adde8a 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -1,6 +1,5 @@ import contextlib import itertools -import json import pathlib import socket import threading @@ -10,7 +9,6 @@ from fast_llm.data.dataset.config import ( 
REDIS_DATA_STREAM, - REDIS_FIELD, REDIS_GROUP_NAME, RedisConfig, SamplingConfig, @@ -29,9 +27,9 @@ def find_free_port(): return s.getsockname()[1] -def push_msg(redis_client, tokens): +def push_message(redis_client, message): """Push a message into FakeRedis stream.""" - redis_client.xadd(REDIS_DATA_STREAM, {REDIS_FIELD: json.dumps({"tokens": tokens, "tokens_dtype": "int64"})}) + redis_client.xadd(REDIS_DATA_STREAM, message) def wait_until_stream_empty( @@ -76,7 +74,7 @@ def producer_loop(): for sample_index in itertools.count(): if stop_event.is_set(): break - push_msg(client, [sample_index] * batch_config.sequence_length) + push_message(client, {"tokens": [sample_index] * batch_config.sequence_length}) if sample_index % 5 == 0: wait_until_stream_empty(client, REDIS_DATA_STREAM, REDIS_GROUP_NAME, stop_event) From ef6a2bb0e3d0f12adb63588fa6c1f2d7f2470f66 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 23 Jan 2026 15:59:25 -0500 Subject: [PATCH 2/4] fixes --- fast_llm/core/distributed.py | 28 +++++++++++------- tests/data/test_streaming.py | 42 +++++++++++++++------------ tests/functional/test_entropy_loss.py | 4 ++- tests/utils/redis.py | 11 ++++--- tests/utils/subtest.py | 17 ++++------- 5 files changed, 54 insertions(+), 48 deletions(-) diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index da443c4f6..0c9b372f0 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -29,6 +29,15 @@ logger = logging.getLogger(__name__) +def _get_device(group: ProcessGroup) -> torch.device: + if torch.distributed.is_nccl_available() and isinstance(group, torch.distributed.ProcessGroupNCCL): + return torch.device(torch.cuda.current_device()) + elif isinstance(group, torch.distributed.ProcessGroupGloo): + return torch.device("cpu") + else: + raise NotImplementedError(type(group)) + + @contextlib.contextmanager def set_timeout(group: ProcessGroup | None, timeout: float | None = None): if group is not None and timeout is not None: @@ -72,12 +81,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: ) -def safe_barrier( - group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None -) -> None: +def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None: if group: hashed = hash(value) % 2**32 - out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device) + out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout) if out != hashed * group.size(): raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})") @@ -88,10 +95,9 @@ def allreduce_scalar( group: torch.distributed.ProcessGroup | None = None, op=ReduceOp.SUM, timeout: float | None = None, - device: torch.device | None = None, ) -> float | int: if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device) + value = torch.full([1], value, dtype=dtype, device=_get_device(group)) with set_timeout(group, timeout): torch.distributed.all_reduce(value, op=op, group=group) return value.item() @@ -106,7 +112,7 @@ def all_gather_scalar( timeout: float | None = None, ): if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) + value = torch.full([1], value, dtype=dtype, device=_get_device(group)) output_tensor = value.new_empty((group.size(),)) with set_timeout(group, timeout): 
torch.distributed.all_gather_into_tensor(output_tensor, value, group=group) @@ -116,7 +122,7 @@ def all_gather_scalar( def broadcast_scalar( - value: float | int, + value: float | int | None, dtype: torch.dtype = torch.float64, group: torch.distributed.ProcessGroup | None = None, src: int = 0, @@ -124,7 +130,7 @@ def broadcast_scalar( ) -> float | int: if not group: return value - tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device())) + tensor = torch.empty([1], dtype=dtype, device=torch.device(_get_device(group))) if group.rank() == src: tensor.fill_(value) broadcast(tensor, src, group, timeout=timeout) @@ -141,14 +147,14 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None if group.rank() == src: tensor = _object_to_tensor(input_object) size = tensor.numel() - broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group)) broadcast_tensor.copy_(tensor) broadcast_scalar(size, torch.int64, group, src) broadcast(broadcast_tensor, src, group) return input_object else: size = int(broadcast_scalar(None, torch.int64, group, src)) - output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + output_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group)) broadcast(output_tensor, src, group) return _tensor_to_object(output_tensor) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 8ff00db6b..05bd05285 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -16,14 +16,13 @@ from fast_llm.data.dataset.streaming import RedisDocument, RedisStreamingDataset from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample -from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.engine.distributed.config import DistributedBackend, DistributedConfig, DistributedDimNames from fast_llm.engine.distributed.distributed import Distributed from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert from tests.conftest import WorkerResources -from tests.utils.redis import make_sampling, push_message, redis_batch_producer +from tests.utils.redis import make_sampling, redis_batch_producer from tests.utils.subtest import DistributedTestContext -from tests.utils.utils import requires_cuda logger = logging.getLogger(__name__) @@ -42,13 +41,13 @@ def fake_redis(monkeypatch): @pytest.mark.parametrize( "documents", [ - ([0, 1, 2],), - ([0, 1, 2], [3, 4, 5, 6]), - ([0, 1, 2], [0, 1, 2, 3, 4], [9, 4]), + (range(3),), + (range(3), range(3, 6)), + (range(3), range(5), [9, 4]), ( - {"tokens": [0, 1, 2], "advantage": 0.33, "old_log_probabilities": [0.25, -0.52, 0.99]}, - {"tokens": [5, 3, 2, 0], "loss_masking_spans": [(0, 1), (2, 3)]}, - {"tokens": [5, 3, 2, 0], "chosen_span": (0, 2), "rejected_span": (2, 3)}, + {"tokens": list(range(3)), "advantage": 0.33, "old_log_probabilities": [0.25, -0.52, 0.99]}, + {"tokens": list(range(5)), "loss_masking_spans": [(0, 1), (2, 3)]}, + {"tokens": list(range(8)), "chosen_span": (0, 2), "rejected_span": (3, 5)}, ), ], ) @@ -60,7 +59,7 @@ def test_streaming_dataset( """StreamingDataset should read a message and convert it into LanguageModelSample.""" stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) dataset_iterator = 
iter(RedisStreamingDataset(stream_config, DistributedConfig())) - documents = [document if isinstance(document, dict) else {"tokens": document} for document in documents] + documents = [document if isinstance(document, dict) else {"tokens": list(document)} for document in documents] for document in documents: fake_redis.xadd(REDIS_DATA_STREAM, RedisDocument.from_dict(document).to_message()) for document in documents: @@ -132,7 +131,7 @@ def test_streaming_sampled_dataset( RedisStreamingDataset(stream_config, distributed.config).sample(make_sampling(5, 1, distributed)) ) for message in messages: - push_message(fake_redis, {"tokens": list(message)}) + fake_redis.xadd(REDIS_DATA_STREAM, RedisDocument.from_dict({"tokens": list(message)}).to_message()) for expected_sample, expected_lengths_ in zip(expected_samples, expected_lengths, strict=True): sample = next(dataset_iterator) assert isinstance(sample, LanguageModelSample) @@ -150,7 +149,13 @@ def _get_distributed_and_batch_config( distributed_config_dict: dict[str, typing.Any], world_size: int = 1 ) -> tuple[DistributedConfig, GPTBatchConfig]: distributed_config = DistributedConfig.from_dict( - distributed_config_dict, {"world_size": world_size, "local_world_size": world_size} + distributed_config_dict, + { + "world_size": world_size, + "local_world_size": world_size, + "use_cuda": False, + "backend": DistributedBackend.gloo, + }, ) with NoAutoValidate(): batch_config = GPTBatchConfig(micro_batch_size=2, sequence_length=10) @@ -221,14 +226,15 @@ def _run_test_data_streaming_distributed( # Import all dynamic classes. TODO: needed? import fast_llm.cli # noqa + print(_DISTRIBUTED_TESTING_CONFIGS) for name, num_gpus, distributed_config_dict in _DISTRIBUTED_TESTING_CONFIGS: with test_context.subtest(base_path, name, num_gpus) as subtest: + print(name, subtest.do_run) if subtest.do_run: distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) _run_test_data_streaming(base_path / name, distributed_config, batch_config, port) -@requires_cuda def test_data_streaming(result_path, worker_resources): distributed_config, batch_config = _get_distributed_and_batch_config({}) path = result_path / "data_streaming/single_gpu" @@ -250,24 +256,22 @@ def test_data_streaming(result_path, worker_resources): ] -@requires_cuda @pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) def test_run_data_streaming_distributed(run_parallel_script, result_path, worker_resources): - if torch.cuda.device_count() < 2: - pytest.skip(f"Not enough GPUs") run_parallel_script( _run_test_data_streaming_distributed, (result_path / "data_streaming", worker_resources.torchrun_port), - world_size=torch.cuda.device_count(), + world_size=4, + backend=DistributedBackend.gloo, + use_cuda=False, # Disable device count check. 
) -@requires_cuda @pytest.mark.slow @pytest.mark.depends_on(on=["test_data_streaming"]) @pytest.mark.parametrize(("name", "num_gpus", "distributed_config_dict"), _DISTRIBUTED_TESTING_CONFIGS) def test_data_streaming_distributed(result_path, name, num_gpus, distributed_config_dict, report_subtest): - report_subtest(path := result_path / f"data_streaming/{name}", num_gpus) + report_subtest(path := result_path / f"data_streaming/{name}", num_gpus, use_cuda=False) distributed_config, batch_config = _get_distributed_and_batch_config(distributed_config_dict, num_gpus) check_data_streaming_results(path, distributed_config, batch_config) diff --git a/tests/functional/test_entropy_loss.py b/tests/functional/test_entropy_loss.py index 9c06c1919..db53fe0d9 100644 --- a/tests/functional/test_entropy_loss.py +++ b/tests/functional/test_entropy_loss.py @@ -176,4 +176,6 @@ def test_run_entropy_loss_distributed(run_parallel_script, result_path): def test_entropy_loss_distributed(result_path, report_subtest, target_format, entropy_loss_type, loss_masking): if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: pytest.skip(reason="Not implemented") - report_subtest(result_path / f"test_entropy_loss/{entropy_loss_type}_{target_format}_{loss_masking}", 2) + report_subtest( + result_path / f"test_entropy_loss/{entropy_loss_type}_{target_format}_{loss_masking}", 2, use_cuda=False + ) diff --git a/tests/utils/redis.py b/tests/utils/redis.py index ec8adde8a..198c6df78 100644 --- a/tests/utils/redis.py +++ b/tests/utils/redis.py @@ -16,6 +16,7 @@ SamplingParameters, StreamingDatasetConfig, ) +from fast_llm.data.dataset.streaming import RedisDocument from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.models.gpt.config import GPTBatchConfig @@ -27,11 +28,6 @@ def find_free_port(): return s.getsockname()[1] -def push_message(redis_client, message): - """Push a message into FakeRedis stream.""" - redis_client.xadd(REDIS_DATA_STREAM, message) - - def wait_until_stream_empty( redis_client, stream_key, @@ -74,7 +70,10 @@ def producer_loop(): for sample_index in itertools.count(): if stop_event.is_set(): break - push_message(client, {"tokens": [sample_index] * batch_config.sequence_length}) + client.xadd( + REDIS_DATA_STREAM, + RedisDocument.from_dict({"tokens": [sample_index] * batch_config.sequence_length}).to_message(), + ) if sample_index % 5 == 0: wait_until_stream_empty(client, REDIS_DATA_STREAM, REDIS_GROUP_NAME, stop_event) diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index 3ca84499e..e5c87f9f5 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -51,12 +51,12 @@ def __enter__(self): self._configure_logging() self._group = self._pool.get_process_group(range(self._world_size), self._rank) # TODO: Barriers needed? - safe_barrier(self._group, "start", device=self._pool.device) + safe_barrier(self._group, "start") return self def __exit__(self, exc_type, exc_val, exc_tb): # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(self._group, "testing end", device=self._pool.device) + safe_barrier(self._group, "testing end") # Let pytest know how things went. # These should already be reported above, we repeat for convenience. if self._failures: @@ -138,13 +138,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): if (group := self._test_context._group) is not None: # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. 
- safe_barrier(group, self._name, device=self._test_context._pool.device) - self._success = ( - allreduce_scalar( - self._success, dtype=torch.int64, group=group, device=self._test_context._pool.device - ) - == group.size() - ) + safe_barrier(group, self._name) + self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size() if self._do_capture and torch.cuda.is_available(): # Free resources to limit memory usage. @@ -201,8 +196,8 @@ def report_subtest(request: pytest.FixtureRequest): verbose = request.config.getoption("verbose") do_capture = request.config.getoption("distributed_capture") - def do_report_subtest(path: pathlib.Path, world_size: int) -> None: - if torch.cuda.device_count() < world_size: + def do_report_subtest(path: pathlib.Path, world_size: int, use_cuda: bool = True) -> None: + if use_cuda and torch.cuda.device_count() < world_size: pytest.skip(f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {world_size}") success = check_subtest_success(path) if not do_capture: From 9c2ec9f975974d27186acdfff6da085657034551 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 26 Jan 2026 16:32:13 -0500 Subject: [PATCH 3/4] stuff --- fast_llm/data/dataset/abstract.py | 2 +- fast_llm/data/dataset/sampled.py | 11 ++++--- fast_llm/data/dataset/streaming.py | 30 +++++++++++-------- fast_llm/data/preparator/gpt_memmap/config.py | 15 ++++++++++ .../data/preparator/gpt_memmap/prepare.py | 11 ++++++- fast_llm/data/preprocessing/language_model.py | 11 +++++-- fast_llm/models/gpt/config.py | 10 +++++++ fast_llm/models/gpt/trainer.py | 5 ++-- tests/data/test_streaming.py | 27 ++++++++++++----- tests/utils/dataset.py | 9 ++++++ tests/utils/model_configs.py | 17 +++++++++++ 11 files changed, 115 insertions(+), 33 deletions(-) diff --git a/fast_llm/data/dataset/abstract.py b/fast_llm/data/dataset/abstract.py index 1df24e92b..520e6d0af 100644 --- a/fast_llm/data/dataset/abstract.py +++ b/fast_llm/data/dataset/abstract.py @@ -61,7 +61,7 @@ def sample(self, config: "SamplingData") -> SampledDataset[SampleType]: class SamplableIterableDataset[SampleType: Sample](SamplableDataset[SampleType]): @abc.abstractmethod - def __iter__(self) -> typing.Iterator[SampleType]: + def iterate(self, sampling: "SamplingData") -> typing.Iterator[SampleType]: pass def sample(self, config: "SamplingData") -> "SampledIterableDataset[SampleType]": diff --git a/fast_llm/data/dataset/sampled.py b/fast_llm/data/dataset/sampled.py index ca76db949..35504df9f 100644 --- a/fast_llm/data/dataset/sampled.py +++ b/fast_llm/data/dataset/sampled.py @@ -442,11 +442,10 @@ def __init__( sampling: SamplingData, ): self._dataset = dataset - self._config = sampling.config - self._parameters = sampling.parameters + self._sampling = sampling self._documents: list[SampleType] = [] self._current_length = 0 - self._sample_length = self._parameters.sequence_length + self._parameters.extra_tokens + self._sample_length = self._sampling.parameters.sequence_length + self._sampling.parameters.extra_tokens # Delay iterator creation to avoid pickling issues. 
self._iterator: typing.Iterator[SampleType] | None = None @@ -458,7 +457,7 @@ def requires_broadcast(self) -> bool: def __getitem__(self, index: int) -> SampleType: if self._iterator is None: - self._iterator = iter(self._dataset) + self._iterator = self._dataset.iterate(self._sampling) while self._current_length < self._sample_length: document = next(self._iterator) if len(document) > self._sample_length: @@ -474,7 +473,7 @@ def __getitem__(self, index: int) -> SampleType: else: last_length = len(self._documents[-1]) remaining_length = last_length - (self._current_length - self._sample_length) - if self._parameters.truncate_documents: + if self._sampling.parameters.truncate_documents: documents = self._documents[:-1] + [self._documents[-1].crop(0, remaining_length)] self._documents = [self._documents[-1].crop(remaining_length, last_length)] else: @@ -486,7 +485,7 @@ def __getitem__(self, index: int) -> SampleType: return sample def __len__(self) -> int: - return self._parameters.num_samples + return self._sampling.parameters.num_samples @property def name(self) -> str: diff --git a/fast_llm/data/dataset/streaming.py b/fast_llm/data/dataset/streaming.py index 48275988f..8e3719986 100644 --- a/fast_llm/data/dataset/streaming.py +++ b/fast_llm/data/dataset/streaming.py @@ -7,7 +7,8 @@ from fast_llm.config import Config, Configurable, Field, config_class from fast_llm.data.dataset.abstract import SamplableIterableDataset -from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_GROUP_NAME, StreamingDatasetConfig +from fast_llm.data.dataset.config import REDIS_DATA_STREAM, REDIS_GROUP_NAME, SamplingData, StreamingDatasetConfig +from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig from fast_llm.data.sample.language_model import LanguageModelSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample @@ -83,24 +84,27 @@ def to_message(self) -> dict[str, str | int | float | bytes]: message["data"] = json.dumps(data) return message - def to_sample(self): + def to_sample(self, preprocessing: LanguageModelPreprocessingConfig | None): sample_size = len(self.tokens) + # TODO: Check explicitly that required data is available? 
return LanguageModelSample( tokens=TokenSample(self.tokens, [sample_size]), loss_masking_spans=( - None - if self.loss_masking_spans is None - else RangeSample([(begin, end) for begin, end in self.loss_masking_spans], sample_size) + RangeSample([(begin, end) for begin, end in self.loss_masking_spans], sample_size) + if preprocessing.use_loss_masking_spans + else None + ), + chosen_spans=RangeSample([self.chosen_span], sample_size) if preprocessing.use_preference_spans else None, + rejected_spans=( + RangeSample([self.rejected_span], sample_size) if preprocessing.use_preference_spans else None ), - chosen_spans=None if self.chosen_span is None else RangeSample([self.chosen_span], sample_size), - rejected_spans=None if self.rejected_span is None else RangeSample([self.rejected_span], sample_size), advantages=( - None - if self.advantage is None - else TokenDataSample(torch.full([sample_size], self.advantage, dtype=torch.float32)) + TokenDataSample(torch.full([sample_size], self.advantage, dtype=torch.float32)) + if preprocessing.use_advantages + else None ), old_log_probabilities=( - None if self.old_log_probabilities is None else TokenDataSample(self.old_log_probabilities) + TokenDataSample(self.old_log_probabilities) if preprocessing.use_old_log_probabilities else None ), ) @@ -125,7 +129,7 @@ def requires_broadcast(self) -> bool: def name(self) -> str: return self._name - def __iter__(self) -> typing.Iterator[LanguageModelSample]: + def iterate(self, sampling: SamplingData) -> typing.Iterator[SampleType]: worker_info = torch.utils.data.get_worker_info() if worker_info is not None and worker_info.num_workers > 1: raise RuntimeError("StreamingDataset can work only with one instance per rank") @@ -166,4 +170,4 @@ def __iter__(self) -> typing.Iterator[LanguageModelSample]: assert stream_key == REDIS_DATA_STREAM.encode() for message_id, message in messages_: print(message) - yield RedisDocument.from_message(message).to_sample() + yield RedisDocument.from_message(message).to_sample(sampling.preprocessing) diff --git a/fast_llm/data/preparator/gpt_memmap/config.py b/fast_llm/data/preparator/gpt_memmap/config.py index a1aadf40a..1645e4dea 100644 --- a/fast_llm/data/preparator/gpt_memmap/config.py +++ b/fast_llm/data/preparator/gpt_memmap/config.py @@ -48,6 +48,10 @@ def has_preference_spans(self) -> bool: def has_images(self) -> bool: return False + @functools.cached_property + def has_grpo_data(self) -> bool: + return False + @config_class(dynamic_type={LanguageModelSourceConfig: "document"}) class DocumentSourceConfig(LanguageModelSourceConfig): @@ -91,6 +95,13 @@ class DocumentSourceConfig(LanguageModelSourceConfig): desc="Field containing image positions in the text.", hint=FieldHint.optional, ) + # TODO: Old log probabilities are made up (zeros) since we don't know the token count in advance. + advantages: str | None = Field( + default=None, + desc="Field containing advantaged for policy optimization." 
+ " Mainly for debugging purposed as advantages are typically generated at runtime.", + hint=FieldHint.optional, + ) @functools.cached_property def columns(self) -> list[str]: @@ -117,6 +128,10 @@ def has_images(self) -> bool: Assert.eq(self.images is None, self.image_positions is None) return self.images is not None + @functools.cached_property + def has_grpo_data(self) -> bool: + return self.advantages is not None + def _validate(self): super()._validate() if self.has_preference_spans and self.has_loss_masking_span: diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index 325d33c43..a45766ab3 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -42,6 +42,7 @@ from fast_llm.data.sample.patch import PatchSample from fast_llm.data.sample.range import RangeSample from fast_llm.data.sample.token import TokenSample +from fast_llm.data.sample.token_data import TokenDataSample from fast_llm.engine.config_utils.data_type import DataType, get_unsigned_integer_type from fast_llm.engine.config_utils.run import log_main_rank from fast_llm.utils import normalize_probabilities, padded_cumsum @@ -226,7 +227,9 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: token_spans_by_type = collections.defaultdict(list) - image_patches = image_token_maps = image_position_ids = patch_counts = None + image_patches = image_token_maps = image_position_ids = patch_counts = advantages = old_log_probabilities = ( + None + ) if isinstance(self._source_schema, ConversationSourceConfig): # Conversation format: tokenize messages and get loss masking spans from chat template @@ -332,6 +335,10 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: else: raise NotImplementedError(f"Unsupported source schema type: {type(self._source_schema)}") + if self._source_schema.has_grpo_data: + advantages = torch.full_like(tokens, sample[self._source_schema.advantages], dtype=torch.float32) + old_log_probabilities = torch.zeros(tokens, sample[self._source_schema.advantages], dtype=torch.float32) + sample_size = len(tokens) return LanguageModelSample( @@ -357,6 +364,8 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: if self._source_schema.has_images else None ), + TokenDataSample(advantages) if self._source_schema.has_grpo_data else None, + TokenDataSample(old_log_probabilities) if self._source_schema.has_grpo_data else None, ) def generate_config_yaml_for_sharded_dst( diff --git a/fast_llm/data/preprocessing/language_model.py b/fast_llm/data/preprocessing/language_model.py index 87d176663..a4b4a2ff8 100644 --- a/fast_llm/data/preprocessing/language_model.py +++ b/fast_llm/data/preprocessing/language_model.py @@ -22,8 +22,7 @@ class LanguageModelPreprocessingConfig(PreprocessingConfig): vocab_size: int | None = Field(default=None) use_loss_masking_spans: bool = Field(default=False) use_preference_spans: bool = Field(default=False) - use_advantages: bool = Field(default=False) - use_old_log_probabilities: bool = Field(default=False) + use_grpo_data: bool = Field(default=False) def _validate(self) -> None: super()._validate() @@ -34,6 +33,14 @@ def _validate(self) -> None: def use_image_patches(self) -> bool: return isinstance(self.image_patches, ImagePatchConfig) + @functools.cached_property + def use_advantages(self) -> bool: + return self.use_grpo_data + + 
@functools.cached_property + def use_old_log_probabilities(self) -> bool: + return self.use_grpo_data + def check_compatibility(self, preprocessing: typing.Self) -> None: Assert.custom(isinstance, preprocessing, LanguageModelPreprocessingConfig) # TODO: Check more tokenizer data, ex. bos/eos tokens? path if points to HF hub? diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index a315beecc..4e2522fc1 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -54,6 +54,16 @@ class GPTBatchConfig(BatchConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) + use_preference_spans: bool = Field( + default=False, + desc="Read dpo data (chosen and rejected spans) from the dataset.", + hint=FieldHint.feature, + ) + use_grpo_data: bool = Field( + default=False, + desc="Read grpo data (advantages and old log probabilities) from the dataset.", + hint=FieldHint.feature, + ) truncate_documents: bool | None = Field( default=True, desc=( diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 768d3fdd7..c07810e34 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -33,11 +33,12 @@ def _get_sampling_parameters( def _get_preprocessing_config( self, *, _return_dict: bool = False ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: + out = { "type": "language_model", "vocab_size": self._config.model.base_model.embeddings.vocab_size, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - # OK since DPO is not supported for MTP. - "use_preference_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), + "use_preference_spans": self._config.batch.use_preference_spans, + "use_grpo_data": self._config.batch.use_grpo_data, } return out if _return_dict else LanguageModelPreprocessingConfig.from_dict(out) diff --git a/tests/data/test_streaming.py b/tests/data/test_streaming.py index 05bd05285..a39228f44 100644 --- a/tests/data/test_streaming.py +++ b/tests/data/test_streaming.py @@ -21,6 +21,7 @@ from fast_llm.models.gpt.config import GPTBatchConfig from fast_llm.utils import Assert from tests.conftest import WorkerResources +from tests.data.common import get_sampling_data from tests.utils.redis import make_sampling, redis_batch_producer from tests.utils.subtest import DistributedTestContext @@ -39,26 +40,36 @@ def fake_redis(monkeypatch): @pytest.mark.parametrize( - "documents", + ("documents", "preprocessing"), [ - (range(3),), - (range(3), range(3, 6)), - (range(3), range(5), [9, 4]), + ((range(3),), {}), + ((range(3), range(3, 6)), {}), + ((range(3), range(5), [9, 4]), {}), + (({"tokens": list(range(5)), "loss_masking_spans": [(0, 1), (2, 3)]},), {"use_loss_masking_spans": True}), ( - {"tokens": list(range(3)), "advantage": 0.33, "old_log_probabilities": [0.25, -0.52, 0.99]}, - {"tokens": list(range(5)), "loss_masking_spans": [(0, 1), (2, 3)]}, - {"tokens": list(range(8)), "chosen_span": (0, 2), "rejected_span": (3, 5)}, + ({"tokens": list(range(8)), "chosen_span": (0, 2), "rejected_span": (3, 5)},), + {"use_preference_spans": True}, + ), + ( + ( + {"tokens": list(range(3)), "advantage": 0.33, "old_log_probabilities": [0.25, -0.52, 0.99]}, + {"tokens": list(range(4)), "advantage": 0.7, "old_log_probabilities": [1, 2, 3, 4]}, + ), + {"use_grpo_data": True}, ), ], ) def test_streaming_dataset( fake_redis: fakeredis.FakeRedis, documents: tuple[list[int] | dict[str, typing.Any], ...], + preprocessing: dict, worker_resources: 
WorkerResources, ): """StreamingDataset should read a message and convert it into LanguageModelSample.""" stream_config = StreamingDatasetConfig(port=worker_resources.torchrun_port) - dataset_iterator = iter(RedisStreamingDataset(stream_config, DistributedConfig())) + dataset_iterator = RedisStreamingDataset(stream_config, DistributedConfig()).iterate( + get_sampling_data(len(documents), preprocessing=LanguageModelPreprocessingConfig.from_dict(preprocessing)) + ) documents = [document if isinstance(document, dict) else {"tokens": list(document)} for document in documents] for document in documents: fake_redis.xadd(REDIS_DATA_STREAM, RedisDocument.from_dict(document).to_message()) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 854ecec36..a1aa3357a 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -129,6 +129,7 @@ def _get_hf_test_dataset( min_loss_masking_spans: int = 0, max_loss_masking_spans: int = 0, has_preference_spans: bool = False, + has_grpo_data: bool = False, min_images: int = 0, max_images: int = 0, min_image_size: int = 4, @@ -153,6 +154,9 @@ def _get_hf_test_dataset( document_sizes, min_images, max_images, min_image_size, max_image_size, random_state ) + if has_grpo_data: + dataset_dict["advantages"] = random_state.randn(num_documents).tolist() + return datasets.Dataset.from_dict(dataset_dict) @@ -168,6 +172,7 @@ def _get_test_dataset( min_loss_masking_spans: int = 0, max_loss_masking_spans: int = 0, has_preference_spans: bool = False, + has_grpo_data: bool = False, splits: dict[str, float] | None = None, min_images: int = 0, max_images: int = 0, @@ -192,6 +197,7 @@ def _get_test_dataset( min_loss_masking_spans=min_loss_masking_spans, max_loss_masking_spans=max_loss_masking_spans, has_preference_spans=has_preference_spans, + has_grpo_data=has_grpo_data, min_images=min_images, max_images=max_images, min_image_size=min_image_size, @@ -207,6 +213,8 @@ def _get_test_dataset( if max_images > 0: source_schema["images"] = "images" source_schema["image_positions"] = "image_positions" + if has_grpo_data: + source_schema["advantages"] = "advantages" download_santacoder_tokenizer() preparator_config = GPTMemmapDatasetPreparatorConfig.from_dict( @@ -239,6 +247,7 @@ def _get_test_dataset( vocab_size=max_vocab_size, use_loss_masking_spans=max_loss_masking_spans > 0, use_preference_spans=has_preference_spans, + use_grpo_data=has_grpo_data, ) return path, config, hf_path, preprocessing diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index f67a52939..19703a1c6 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -735,6 +735,23 @@ def update_and_add_testing_config( auto_model_class=transformers.AutoModelForImageTextToText, ) +update_and_add_testing_config( + # Tests mixture of experts, mixtral converter. 
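+    # Adds a grpo loss on the model head and reads grpo batch data (advantages and old log probabilities).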
+ "llama", + "llama_grpo", + updates={("model", "base_model", "head", "losses"): {"grpo": {"type": "grpo"}}, ("batch", "use_grpo_data"): True}, + # TODO: New base image broke mixtral + groups={ + ModelTestingGroup.basic: ModelTestingGroupAction.normal, + ModelTestingGroup.checkpoint: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.convert: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.generate: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, + ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, + ModelTestingGroup.streaming: ModelTestingGroupAction.normal, + }, +) + update_and_add_testing_config( # Tests hybrid with attention + gated delta net mixer. From af5ce31b26507e9a29b027f072e7983d3f68fece Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Mon, 26 Jan 2026 20:18:53 -0500 Subject: [PATCH 4/4] Fixes --- .../data/preparator/gpt_memmap/prepare.py | 17 ++++++----- fast_llm/data/preprocessing/tokenizer.py | 22 ++++++++++---- fast_llm/data/sample/language_model.py | 28 ++++++++++++++++-- fast_llm/data/sample/token_data.py | 29 +++++++++---------- fast_llm/models/gpt/model.py | 16 ++++++++++ tests/utils/dataset.py | 7 +++-- tests/utils/model_configs.py | 6 ++-- 7 files changed, 90 insertions(+), 35 deletions(-) diff --git a/fast_llm/data/preparator/gpt_memmap/prepare.py b/fast_llm/data/preparator/gpt_memmap/prepare.py index a45766ab3..7fc2e35af 100644 --- a/fast_llm/data/preparator/gpt_memmap/prepare.py +++ b/fast_llm/data/preparator/gpt_memmap/prepare.py @@ -223,6 +223,7 @@ def _preprocessing_config(self) -> LanguageModelPreprocessingConfig: ), use_loss_masking_spans=self._source_schema.has_loss_masking_span, use_preference_spans=self._source_schema.has_preference_spans, + use_grpo_data=self._source_schema.has_grpo_data, ) def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: @@ -337,35 +338,37 @@ def _prepare_sample(self, sample: dict[str, typing.Any]) -> LanguageModelSample: if self._source_schema.has_grpo_data: advantages = torch.full_like(tokens, sample[self._source_schema.advantages], dtype=torch.float32) - old_log_probabilities = torch.zeros(tokens, sample[self._source_schema.advantages], dtype=torch.float32) + old_log_probabilities = torch.zeros_like(tokens, dtype=torch.float32) sample_size = len(tokens) return LanguageModelSample( TokenSample(tokens, [sample_size]), - ( + loss_masking_spans=( RangeSample(token_spans_by_type[SpanType.loss_masking], sample_size) if self._source_schema.has_loss_masking_span else None ), - ( + chosen_spans=( RangeSample(token_spans_by_type[SpanType.chosen], sample_size) if self._source_schema.has_preference_spans else None ), - ( + rejected_spans=( # `tokenize_with_spans` excludes the final eod token from the rejected span, but we want to include it. 
RangeSample([(begin, end + 1) for begin, end in token_spans_by_type[SpanType.rejected]], sample_size) if self._source_schema.has_preference_spans else None ), - ( + image_patches=( PatchSample(image_patches, image_token_maps, image_position_ids, sample_size, patch_counts) if self._source_schema.has_images else None ), - TokenDataSample(advantages) if self._source_schema.has_grpo_data else None, - TokenDataSample(old_log_probabilities) if self._source_schema.has_grpo_data else None, + advantages=TokenDataSample(advantages) if self._source_schema.has_grpo_data else None, + old_log_probabilities=( + TokenDataSample(old_log_probabilities) if self._source_schema.has_grpo_data else None + ), ) def generate_config_yaml_for_sharded_dst( diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 157744f51..4408ca772 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -11,6 +11,7 @@ if typing.TYPE_CHECKING: import numpy as np import torch + import transformers @config_class(dynamic_type={PreprocessingConfig: "tokenizer"}) @@ -52,7 +53,7 @@ def __init__(self, config: ConfigType): from transformers import AutoTokenizer log_main_rank(f"> loading tokenizer from {config.path} ...") - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer: "transformers.PreTrainedTokenizer" = AutoTokenizer.from_pretrained( pretrained_model_name_or_path=self._config.path, errors="replace", max_len=None, @@ -70,10 +71,15 @@ def __init__(self, config: ConfigType): @functools.cached_property def vocab_size(self) -> int: - out = len(self.tokenizer) - if self._config.max_vocab_size is not None: - out = min(out, self._config.max_vocab_size) - return out + return ( + self._tokenizer_vocab_size + if self._config.max_vocab_size is None + else min(self._tokenizer_vocab_size, self._config.max_vocab_size) + ) + + @functools.cached_property + def _tokenizer_vocab_size(self) -> int: + return len(self.tokenizer) @property def vocab(self) -> dict[str, int]: @@ -99,7 +105,11 @@ def tokenize( tokens = ( torch.tensor( tokens, - dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch, + dtype=( + torch.int64 + if self._tokenizer_vocab_size > torch.iinfo(data_type.torch).max + else data_type.torch + ), ) % self._config.max_vocab_size ).to(data_type.torch) diff --git a/fast_llm/data/sample/language_model.py b/fast_llm/data/sample/language_model.py index e3dab9bc2..aa09e1467 100644 --- a/fast_llm/data/sample/language_model.py +++ b/fast_llm/data/sample/language_model.py @@ -39,7 +39,13 @@ RangeWriter, ) from fast_llm.data.sample.token import TokenBatch, TokenReaderConfig, TokenSample, TokenWriter -from fast_llm.data.sample.token_data import TokenDataBatch, TokenDataReader, TokenDataReaderConfig, TokenDataSample +from fast_llm.data.sample.token_data import ( + TokenDataBatch, + TokenDataReader, + TokenDataReaderConfig, + TokenDataSample, + TokenDataWriter, +) from fast_llm.engine.config_utils.data_type import DataType from fast_llm.utils import Assert @@ -443,9 +449,11 @@ def __enter__(self): if self._preprocessing_config.use_image_patches: self._image_patches_writer = PatchWriter(self._path.joinpath("image_patches")).__enter__() if self._preprocessing_config.use_advantages: - self._advantages_writer = PatchWriter(self._path.joinpath("advantages")).__enter__() + self._advantages_writer = TokenDataWriter(self._path.joinpath("advantages")).__enter__() if 
self._preprocessing_config.use_old_log_probabilities: - self._old_log_probabilities_writer = PatchWriter(self._path.joinpath("old_log_probabilities")).__enter__() + self._old_log_probabilities_writer = TokenDataWriter( + self._path.joinpath("old_log_probabilities") + ).__enter__() return self def write(self, document: LanguageModelSample): @@ -525,6 +533,20 @@ def __exit__(self, exc_type, exc_val, exc_tb): config.image_patches.begin, config.image_patches.end, ) + if self._preprocessing_config.use_advantages: + _copy_chunked( + self._path.joinpath("advantages"), + self._stream, + config.advantages.begin, + config.advantages.end, + ) + if self._preprocessing_config.use_old_log_probabilities: + _copy_chunked( + self._path.joinpath("old_log_probabilities"), + self._stream, + config.old_log_probabilities.begin, + config.old_log_probabilities.end, + ) self._directory.cleanup() super().__exit__(exc_type, exc_val, exc_tb) diff --git a/fast_llm/data/sample/token_data.py b/fast_llm/data/sample/token_data.py index 6d2a6f9d1..cf094bd4d 100644 --- a/fast_llm/data/sample/token_data.py +++ b/fast_llm/data/sample/token_data.py @@ -9,7 +9,7 @@ from fast_llm.data.preprocessing.abstract import PreprocessingConfig from fast_llm.data.sample.abstract import ( Batch, - MemmapIndexedDatasetReader, + MemmapReader, MemmapReaderBase, MemmapReaderBaseConfig, MemmapReaderConfig, @@ -109,7 +109,7 @@ def blend_metadata(cls, metadata: list[dict[str, typing.Any]]) -> dict[str, typi } -class TokenDataReader[ConfigType: TokenDataReaderConfig](MemmapIndexedDatasetReader[ConfigType]): +class TokenDataReader[ConfigType: TokenDataReaderConfig](MemmapReader[ConfigType]): def __init__(self, config: ConfigType, buffer: memoryview, model_preprocessing: PreprocessingConfig | None = None): super().__init__(config, buffer, model_preprocessing) self._data = torch.frombuffer( @@ -125,20 +125,15 @@ def get_document(self, index: int, begin: int, end: int) -> Sample: begin_ = self._size_cumsums[index].item() return TokenDataSample(self._data[begin_ + begin : begin_ + end]) - def get_split(self, begin_ratio: float, end_ratio: float) -> tuple[int, int, dict[str, typing.Any]]: - Assert.custom(lambda x: x == sorted(x), [0, begin_ratio, end_ratio, 1]) - begin_index = _get_nearest_split(self._size_cumsums[1:], begin_ratio * self.num_tokens) - end_index = _get_nearest_split(self._size_cumsums[1:], end_ratio * self.num_tokens) + def get_split(self, begin_index: int, end_index: int) -> dict[str, typing.Any]: + Assert.custom(lambda x: x == sorted(x), [0, begin_index, end_index, self._config.num_documents]) - return ( - begin_index, - end_index, - { - "num_tokens": self._size_cumsums[end_index].item() - self._size_cumsums[begin_index].item(), - "num_documents": end_index - begin_index, - "data_type": str(self._config.data_type), - }, - ) + return { + "num_tokens": self._size_cumsums[end_index].item() - self._size_cumsums[begin_index].item(), + "num_documents": end_index - begin_index, + "data_type": str(self._config.data_type), + "shape": self._config.shape, + } class EmptyPatchReader[ConfigType: PatchReaderBaseConfig](MemmapReaderBase[ConfigType]): @@ -159,14 +154,17 @@ def __enter__(self): super().__enter__() self._size_cumsum = [0] self._data_type = None + self._shape = None return self def write(self, document: TokenDataSample): super().write(document) if self._data_type is None: self._data_type = document.data.dtype + self._shape = document.data.shape[1:] else: Assert.eq(self._data_type, document.data.dtype) + Assert.eq(self._shape, 
document.data.shape[1:]) self._stream.write(document.data.numpy().tobytes()) self._size_cumsum.append(self._size_cumsum[-1] + len(document.data)) @@ -185,6 +183,7 @@ def _get_config(self, begin: int, end: int): end=end, num_documents=len(self._size_cumsum) - 1, num_tokens=self._size_cumsum[-1], + shape=self._shape, data_type=DataType.from_torch(self._data_type), preprocessing=self._preprocessing_config, ) diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bd2932984..eb694a8a6 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -293,6 +293,22 @@ def preprocess_batch( labels_begin, labels_end ).ranges + if batch.advantages is not None: + kwargs[LanguageModelKwargs.advantages] = batch.advantages.crop(labels_begin, labels_end).data + if kwargs[AttentionKwargs.sequence_first]: + kwargs[LanguageModelKwargs.advantages] = ( + kwargs[LanguageModelKwargs.advantages].transpose(0, 1).contiguous() + ) + + if batch.old_log_probabilities is not None: + kwargs[LanguageModelKwargs.old_log_probabilities] = batch.old_log_probabilities.crop( + labels_begin, labels_end + ).data + if kwargs[AttentionKwargs.sequence_first]: + kwargs[LanguageModelKwargs.old_log_probabilities] = ( + kwargs[LanguageModelKwargs.old_log_probabilities].transpose(0, 1).contiguous() + ) + tokens = ( cropped_tokens.tokens.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index a1aa3357a..6456f1589 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -331,9 +331,11 @@ def get_model_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset", seed=1234, + num_documents=200, max_loss_masking_spans=5, + has_grpo_data=True, max_vocab_size=MODEL_TEST_VOCAB_SIZE, - splits={"training": 969, "validation": 30, "test": 1}, + splits={"training": 180, "validation": 19, "test": 1}, config_only=config_only, ) @@ -342,6 +344,7 @@ def get_multimodal_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset_multimodal", seed=1234, + num_documents=200, max_vocab_size=MODEL_TEST_VOCAB_SIZE, max_images=2, image_patch_config=ImagePatchConfig( @@ -352,6 +355,6 @@ def get_multimodal_test_dataset(config_only: bool = False): image_break_token=None, image_end_token=None, ), - splits={"training": 969, "validation": 30, "test": 1}, + splits={"training": 180, "validation": 19, "test": 1}, config_only=config_only, ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 19703a1c6..8c5be2979 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -739,8 +739,10 @@ def update_and_add_testing_config( # Tests mixture of experts, mixtral converter. "llama", "llama_grpo", - updates={("model", "base_model", "head", "losses"): {"grpo": {"type": "grpo"}}, ("batch", "use_grpo_data"): True}, - # TODO: New base image broke mixtral + updates={ + ("model", "base_model", "head", "losses"): {"grpo": {"type": "grpo"}}, + ("batch", "use_grpo_data"): True, + }, groups={ ModelTestingGroup.basic: ModelTestingGroupAction.normal, ModelTestingGroup.checkpoint: ModelTestingGroupAction.not_implemented,