43 changes: 42 additions & 1 deletion README.md
@@ -25,6 +25,7 @@ The library now supports reasoning traces through the `reasoning_content` field
- [Using the library](#using-the-library)
- [Data format](#data-format)
- [Reasoning content support](#reasoning-content-support-1)
- [Continual pretraining mode](#continual-pretraining-mode)
- [Documentation](#documentation)
- [Learning about the training arguments](#learning-about-training-arguments)
- [`TrainingArgs`](#trainingargs)
@@ -122,6 +123,46 @@ The library now supports an optional `reasoning_content` field in addition to th
}
```

## Continual pretraining mode

In addition to instruction tuning, the library can run document-style continual pretraining on raw text corpora.
Enable this by supplying a block size when invoking `main_ds.py`:

```bash
torchrun main_ds.py \
--model_name_or_path mistralai/Mistral-7B-v0.1 \
--data_path /data/documents.jsonl \
--ckpt_output_dir ./checkpoints \
--effective_batch_size 128 \
--max_batch_len 60000 \
--block-size 8192 \
--document-column-name text # optional, defaults to "document"
```

- `--block-size` (required) enables continual pretraining mode and controls how many tokens are packed into each training block (a conceptual sketch of the packing follows this list).
- `--document-column-name` (optional) specifies which JSONL field contains the raw document text.
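
The packing itself happens inside the training data loader; conceptually it concatenates the tokenized documents and slices the resulting token stream into fixed-size chunks, roughly like the illustrative sketch below (a simplification, not the library's actual implementation — in particular, how the final partial block is handled here is an assumption):

```python
# Illustrative only: conceptual sketch of fixed-size block packing.
from typing import Iterable


def pack_into_blocks(token_streams: Iterable[list[int]], block_size: int) -> list[list[int]]:
    """Concatenate tokenized documents and split the stream into fixed-size blocks."""
    buffer: list[int] = []
    blocks: list[list[int]] = []
    for tokens in token_streams:
        buffer.extend(tokens)
        while len(buffer) >= block_size:
            blocks.append(buffer[:block_size])
            buffer = buffer[block_size:]
    # Trailing tokens shorter than block_size are dropped in this sketch;
    # whether the real loader drops or pads them is not specified in this PR.
    return blocks


# Example: three short "documents" packed into blocks of 8 tokens each.
docs = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], [12, 13, 14]]
print(pack_into_blocks(docs, block_size=8))
# [[1, 2, 3, 4, 5, 6, 7, 8]] -- the remaining 6 tokens do not fill a block
```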

The same options are available programmatically via `TrainingArgs.pretraining_config`:

```python
from instructlab.training import TrainingArgs, PretrainingConfig

train_args = TrainingArgs(
model_name_or_path="mistralai/Mistral-7B-v0.1",
data_path="documents.jsonl",
ckpt_output_dir="./checkpoints",
max_seq_len=4096,
max_batch_len=40000,
effective_batch_size=128,
pretraining_config=PretrainingConfig(
block_size=2048,
document_column_name="text", # optional
),
)
```

When a pretraining config is provided, `process_documents_for_pretraining()` is invoked under the hood to tokenize raw documents before training.
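
If you only want the document tokenization step, the same function the training entry point calls can be invoked directly (a minimal sketch based on the signature added in `data_process.py`; the paths below are placeholders):

```python
from instructlab.training.data_process import process_documents_for_pretraining

# Tokenizes each record's raw-text column into [BOS][tokens][EOS] and writes
# one JSONL record per document ({"input_ids": [...], "len": N}) to
# <data_output_path>/data.jsonl. Blocking into fixed-size chunks happens
# later, during training.
process_documents_for_pretraining(
    data_path="documents.jsonl",           # placeholder input path
    data_output_path="./processed-data",   # placeholder output directory
    model_path="mistralai/Mistral-7B-v0.1",
    num_cpu_procs=8,
    document_column_name="text",           # optional; defaults to "document"
)
```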

**Standard message structure:**

```json
Expand All @@ -139,7 +180,7 @@ The library now supports an optional `reasoning_content` field in addition to th
}
```

#### Important Notes
### Important Notes

1. **Automatic reasoning content processing**: If `reasoning_content` exists in a message, it will always be processed and unmasked as long as the message role is targeted for unmasking. This ensures that reasoning traces are properly included in the training data.

2 changes: 2 additions & 0 deletions src/instructlab/training/__init__.py
@@ -10,6 +10,7 @@
"FSDPOptions",
"ShardingStrategies",
"DistributedBackend",
"PretrainingConfig",
)

# First Party
@@ -23,6 +24,7 @@
DistributedBackend,
FSDPOptions,
LoraOptions,
PretrainingConfig,
QuantizeDataType,
ShardingStrategies,
TorchrunArgs,
32 changes: 32 additions & 0 deletions src/instructlab/training/config.py
@@ -63,10 +63,34 @@ class DataProcessArgs(BaseModel):
description="this is the number of CPU procs we use for data processing parallelization",
)

# Pretraining mode flag
is_pretraining: bool = Field(
default=False,
description="Enable pretraining mode: tokenizes raw documents without chat templates or chunking",
)
pretraining_column_name: str = Field(
default="document",
description="the name of the column containing the text to pretrain on",
)

# disable the protected namespace for the model_config field
model_config = ConfigDict(protected_namespaces=())


class PretrainingConfig(BaseModel):
"""
Configuration for pretraining mode.
"""

block_size: int = Field(
description="Size of each block in tokens for pretraining datasets."
)
document_column_name: str = Field(
default="document",
description="Name of the column containing raw documents for pretraining.",
)


# public API
class TorchrunArgs(BaseModel):
"""
@@ -266,6 +290,14 @@ class TrainingArgs(BaseModel):
# "last_epoch". This works alongside the '--checkpoint_at_epoch' flag.
keep_last_checkpoint_only: Optional[bool] = False

pretraining_config: Optional[PretrainingConfig] = Field(
default="document",
description=(
"Pretraining configuration. When provided, enables block-based sampling "
"for raw document pretraining datasets."
),
)

# TODO(osilkin):
# we are only exposing this here because `run_training` today is implicitly coupled
# with `process_data`. Since we don't have a specific field for data processing arguments,
97 changes: 95 additions & 2 deletions src/instructlab/training/data_process.py
@@ -412,7 +412,10 @@ def process_messages_into_input_ids_with_chat_template(args: DataProcessArgs):
logger.info("Tokenizing the dataset with %s tokenizer...", args.model_path)
data_with_input_ids = data.map(
lambda x: {
"input_ids": tokenizer.apply_chat_template(x["messages"], tokenize=True),
# newer versions of transformers have `return_dict=True` by default
"input_ids": tokenizer.apply_chat_template(
x["messages"], tokenize=True, return_dict=False
),
"unmask": bool(x["unmask"]) if "unmask" in x else False,
},
num_proc=NUM_PROC,
@@ -687,7 +690,8 @@ def unmask_messages(
if regions:
message_regions_map[idx] = regions

input_ids = tokenizer.apply_chat_template(msgs_with_unmasking)
# newer versions of transformers have `return_dict=True` by default
input_ids = tokenizer.apply_chat_template(msgs_with_unmasking, return_dict=False)

# Get token IDs for all unmask tokens
unmask_begin_token_id = tokenizer.encode(
@@ -1133,6 +1137,95 @@ def process_messages_into_input_ids(
save_dataset(final_dataset, data_output_path, num_cpu_procs)


def process_documents_for_pretraining(
data_path: str,
data_output_path: str,
model_path: str,
num_cpu_procs: int,
document_column_name: str = "document",
) -> None:
"""
Process raw documents for pretraining by tokenizing without chunking.

Outputs one JSONL record per document with only input_ids (no labels).
Blocking/chunking happens later during training.

Pattern: Each document → [BOS][tokens][EOS]

Args:
data_path: Path to input JSONL with {"document": "..."} records (column name is configurable)
data_output_path: Directory for processed data output
model_path: Path to model/tokenizer
num_cpu_procs: Number of parallel processes
document_column_name: Name of the column containing the documents
"""
ensure_can_write_to_directory(data_output_path)

# Load and validate dataset
try:
data = load_dataset("json", data_files=data_path, split="train")
except Exception as e:
raise ValueError(
"Malformed or missing data, please ensure your dataset is correctly formatted"
) from e

if data.num_rows == 0:
raise ValueError("The provided dataset is empty")

if document_column_name not in data.column_names:
raise ValueError(
f"Pretraining data must have '{document_column_name}' field. Found: {data.column_names}"
)

logger.info("Loading tokenizer from %s", model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

if tokenizer.eos_token_id is None:
raise ValueError("Tokenizer must have an EOS token defined for pretraining")

logger.info("Tokenizing %d documents for pretraining...", data.num_rows)

# Tokenize each document: encode() adds BOS, then append EOS
def tokenize_document(sample):
input_ids = tokenizer.encode(
sample[document_column_name], add_special_tokens=True
)

# ensures eos token is present without double-adding it.
if input_ids[-1] != tokenizer.eos_token_id:
input_ids.append(tokenizer.eos_token_id)

return {
"input_ids": input_ids,
"len": len(input_ids),
}

tokenized_data = data.map(
tokenize_document,
num_proc=num_cpu_procs,
desc="Tokenizing documents",
remove_columns=data.column_names,
)

# Calculate statistics
total_tokens = sum(tokenized_data["len"])
avg_tokens = total_tokens / len(tokenized_data)
logger.info(f"Processed {len(tokenized_data):,} documents")
logger.info(f"Total tokens: {total_tokens:,}")
logger.info(f"Average tokens per document: {avg_tokens:.1f}")

# Save to JSONL (one record per document)
os.makedirs(data_output_path, exist_ok=True)
output_file = Path(data_output_path) / "data.jsonl"

tokenized_data.to_json(
output_file, num_proc=num_cpu_procs, lines=True, orient="records"
)

logger.info(f"Saved tokenized documents to {output_file}")
logger.info("Note: Blocking into fixed-size chunks will happen during training")


def ensure_can_write_to_directory(output_dir: str) -> None:
"""
Ensure that we can write to the output directory.
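
As a quick sanity check of the artifact this function writes, each line of the generated `data.jsonl` can be read back and inspected (a minimal sketch; the output directory and model name are placeholders):

```python
import json
from pathlib import Path

from transformers import AutoTokenizer

# Assumes process_documents_for_pretraining() has already written data.jsonl here.
output_file = Path("./processed-data") / "data.jsonl"  # placeholder path
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")

with output_file.open() as f:
    record = json.loads(f.readline())

# Each record carries the tokenized document and its length.
print(record["len"], record["input_ids"][:5])

# Per the tokenization logic above, every sequence should end with EOS.
assert record["input_ids"][-1] == tokenizer.eos_token_id
```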
76 changes: 57 additions & 19 deletions src/instructlab/training/main_ds.py
@@ -6,7 +6,6 @@
import logging
import os
import subprocess
import sys
import time
import warnings

@@ -47,6 +46,7 @@
from instructlab.training.config import (
DistributedBackend,
ModelTypes,
PretrainingConfig,
TorchrunArgs,
TrainingArgs,
)
@@ -364,6 +364,7 @@ def main(args):
batch_size = args.effective_batch_size

pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0

train_loader = get_data_loader(
data_path=args.data_path,
batch_size=batch_size,
@@ -374,6 +375,7 @@
num_workers=8, # I don't like this but am setting it for consistency
flash_enabled=flash_enabled,
pad_token_id=pad_token_id,
pretraining_config=getattr(args, "pretraining_config", None),
)

if args.local_rank == 0:
@@ -469,18 +471,27 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
)

if train_args.process_data:
# TODO(osilkin):
# Decouple the data processing logic from training.
# Now that we've decided that repos will be less tethered to the
# design choices of the `ilab` CLI, we can make this change.
dp.process_data(
data_output_path=train_args.data_output_dir,
model_path=train_args.model_path,
data_path=train_args.data_path,
max_seq_len=train_args.max_seq_len,
chat_tmpl_path=train_args.chat_tmpl_path,
num_cpu_procs=train_args.data_process_num_cpu_procs,
)
if train_args.pretraining_config is not None:
dp.process_documents_for_pretraining(
data_path=train_args.data_path,
data_output_path=train_args.data_output_dir,
model_path=train_args.model_path,
num_cpu_procs=train_args.data_process_num_cpu_procs,
document_column_name=train_args.pretraining_config.document_column_name,
)
else:
# TODO(osilkin):
# Decouple the data processing logic from training.
# Now that we've decided that repos will be less tethered to the
# design choices of the `ilab` CLI, we can make this change.
dp.process_data(
data_output_path=train_args.data_output_dir,
model_path=train_args.model_path,
data_path=train_args.data_path,
max_seq_len=train_args.max_seq_len,
chat_tmpl_path=train_args.chat_tmpl_path,
num_cpu_procs=train_args.data_process_num_cpu_procs,
)

if not os.path.exists(train_args.ckpt_output_dir):
os.makedirs(train_args.ckpt_output_dir, exist_ok=True)
@@ -537,6 +548,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
]
)

if train_args.pretraining_config is not None:
command.append(f"--block-size={train_args.pretraining_config.block_size}")
command.append(
f"--document-column-name={train_args.pretraining_config.document_column_name}"
)

if train_args.chat_tmpl_path is not None:
command.append(f"--chat-tmpl-path={train_args.chat_tmpl_path}")

@@ -647,15 +664,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
return

# wait for the process to exit so we can properly read the exit code
process.wait(timeout=60)
process_code = process.poll()
failure = process_code != 0

if not failure:
return_code = process.wait(timeout=60) # wait for 1 min or error
if return_code == 0:
logger.info("Operation completed successfully! 🎉")
else:
logger.error(
f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {process_code}"
f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {return_code}"
)
⚠️ Potential issue | 🟠 Major

Fix the misleading error message and incorrect process handling logic.

The process.wait(timeout=60) call blocks until the subprocess exits (or raises TimeoutExpired). When it returns a non-zero return_code, the process has already exited with an error. However, the log message on line 674 incorrectly states "has not exited yet" and the subsequent terminate() call is unnecessary since the process is already dead.

🔎 Proposed fix
-        return_code = process.wait(timeout=60)  # wait for 1 min or error
-        if return_code == 0:
+        try:
+            return_code = process.wait(timeout=60)
+        except subprocess.TimeoutExpired:
+            logger.error(
+                "Training subprocess did not exit within 60s. Sending SIGTERM."
+            )
+            process.terminate()
+            try:
+                process.wait(timeout=60)
+            except subprocess.TimeoutExpired:
+                logger.error("Training subprocess did not terminate, sending SIGKILL.")
+                process.kill()
+            raise RuntimeError("Training subprocess timed out")
+        
+        if return_code == 0:
             logger.info("Operation completed successfully! 🎉")
+            return
         else:
             logger.error(
-                f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {return_code}"
+                f"Training subprocess exited with non-zero code: {return_code}"
             )
-
-        process.terminate()
-        try:
-            logger.info("Waiting for process to exit, 60s...")
-            process.wait(timeout=60)
-        except subprocess.TimeoutExpired:
-            logger.error(
-                "Training subprocess did not terminate before timeout, sending SIGKILL."
-            )
-            process.kill()
+            failure = True

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/instructlab/training/main_ds.py around lines 669-675, the code treats a
non-zero return value from process.wait(timeout=60) as "has not exited yet" and
calls terminate() incorrectly; instead, change the logic so that a non-zero
return_code is logged as a process-exited-with-error (include the return_code
and ideally any stderr output) and do NOT call terminate() in that branch; add
an explicit except subprocess.TimeoutExpired handler around
process.wait(timeout=60) that logs a timeout and then calls process.terminate()
(and optionally process.kill() after a short grace period) to handle processes
that truly do not exit within the timeout.


process.terminate()
@@ -784,6 +798,18 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
help="Which modules we should target for injecting LoRA layers. Defaults to selecting all projection layers when no values are provided.",
)
parser.add_argument("--max_batch_len", type=int, default=60000)
parser.add_argument(
"--block-size",
type=int,
default=None,
help="When provided, enables pretraining mode with the given token block size.",
)
parser.add_argument(
"--document-column-name",
type=str,
default=None,
help="Column name containing raw documents for continual pretraining data.",
)
parser.add_argument(
"--cpu_offload_optimizer",
action="store_true",
Expand Down Expand Up @@ -856,6 +882,18 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
help="Epsilon for numerical stability in AdamW optimizer.",
)
args = parser.parse_args()
if args.document_column_name is not None and args.block_size is None:
parser.error("--document-column-name requires --block-size to be specified.")

if args.block_size is not None:
pretraining_kwargs = {}
if args.document_column_name is not None:
pretraining_kwargs["document_column_name"] = args.document_column_name
args.pretraining_config = PretrainingConfig(
block_size=args.block_size, **pretraining_kwargs
)
else:
args.pretraining_config = None
set_random_seed(args.seed)
main(args)
