Open
Labels: bug (Something isn't working), strategy: fsdp (Fully Sharded Data Parallel), ver: 2.5.x
Bug description
I'm not sure whether this indicates a bug or whether I'm misusing Lightning, but at this point I've tried a number of different things, and everything I try with FSDP crashes, always with a size mismatch deep inside torch. I see it with a CLI version (which works fine on a single GPU or with DDP across multiple GPUs) and with API code like the example below. I do have a plain PyTorch/FSDP implementation that seems to work fine, so the problem isn't inherent to PyTorch FSDP.
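For context, here is a minimal sketch of the kind of plain-PyTorch FSDP wrapping that works (illustrative only, not my exact script; it assumes a torchrun launch and the stock `transformer_auto_wrap_policy`):

```python
# Illustrative sketch of the plain-PyTorch FSDP path (not the exact working script):
# launched with torchrun, same model, standard transformer_auto_wrap_policy.
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

model = AutoModelForCausalLM.from_config(
    AutoConfig.from_pretrained("NousResearch/Llama-3.2-1B")
)
model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    auto_wrap_policy=functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={LlamaDecoderLayer},
    ),
    use_orig_params=True,  # keeps the tied lm_head/embedding weights intact
)
# ... standard training loop: forward, loss.backward(), optimizer.step()
```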
I'd greatly appreciate knowing whether this is a Lightning bug or something I'm doing wrong.
What version are you seeing the problem on?
v2.5
Reproduced in studio
No response
How to reproduce the bug
```python
import torch
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy
import transformers

LEN = 1000
BATCH_SIZE = 8
MODEL = "NousResearch/Llama-3.2-1B"

L.seed_everything(42)

##
## Wrap a Llama model in the Lightning wrapper
##
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM

class LlamaPretrainedModule(L.LightningModule):
    def __init__(self, model_name="meta-llama/Llama-2-7b-hf"):
        super().__init__()
        self.model = AutoModelForCausalLM.from_config(
            AutoConfig.from_pretrained(model_name)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.save_hyperparameters()

    def forward(self, input_ids, attention_mask=None, labels=None):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        output = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
        loss = output.loss
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-5)
        return optimizer

##
## Create data
##
from datasets import Dataset
from transformers import DataCollatorForLanguageModeling

def tokenize(dataset):
    return tokenizer(dataset['text'], padding='max_length', max_length=32, return_length=True)

dataset = Dataset.from_list(
    [{"text": "this is a line of text {}".format(i)} for i in range(10)]
)
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenizer.pad_token = tokenizer.eos_token
tokenized_dataset = dataset.map(
    tokenize, batched=True, remove_columns=dataset.column_names
)
train_dataloader = DataLoader(
    tokenized_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    num_workers=0,
)

##
## Load Model
##
torch.set_num_threads(16)
torch.set_float32_matmul_precision('high')
model = LlamaPretrainedModule(MODEL)

##
## Setup and Run Trainer
##
policy = {transformers.models.llama.modeling_llama.LlamaDecoderLayer, torch.nn.Embedding}
strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    auto_wrap_policy=policy,
    # activation_checkpointing_policy=policy  # same as autowrap
)
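# Note: a set of module classes here acts like FSDP's ModuleWrapPolicy, so
# each LlamaDecoderLayer and nn.Embedding instance gets its own FSDP unit,
# while lm_head (tied to the embedding in this model) stays in the root unit.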
trainer = L.Trainer(
    accelerator="cuda",
    devices=4,
    max_epochs=1,
    strategy=strategy,
    enable_checkpointing=False,
)
trainer.fit(model, train_dataloader)
trainer.print(torch.cuda.memory_summary())
```

Error messages and logs
```
[gpu:Lightning] CUDA_VISIBLE_DEVICES=1,2,3,4 python llamafsdp.py
Seed set to 42
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 141.20 examples/s]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/4
[rank: 1] Seed set to 42
[rank: 3] Seed set to 42
[rank: 2] Seed set to 42
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 254.08 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 254.72 examples/s]
Map: 100%|███████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 211.39 examples/s]
Initializing distributed: GLOBAL_RANK: 2, MEMBER: 3/4
Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/4
Initializing distributed: GLOBAL_RANK: 3, MEMBER: 4/4
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 4 processes
----------------------------------------------------------------------------------------------------
LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [1,2,3,4]
LOCAL_RANK: 2 - CUDA_VISIBLE_DEVICES: [1,2,3,4]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [1,2,3,4]
LOCAL_RANK: 3 - CUDA_VISIBLE_DEVICES: [1,2,3,4]
| Name | Type | Params | Mode
---------------------------------------------------
0 | model | LlamaForCausalLM | 308 M | train
---------------------------------------------------
308 M Trainable params
0 Non-trainable params
308 M Total params
1,235.814 Total estimated model params size (MB)
232 Modules in train mode
0 Modules in eval mode
/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:436: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (1) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0: 0%| | 0/1 [00:00<?, ?it/s][rank2]: Traceback (most recent call last):
[rank2]: File "/home/user/Lightning/llamafsdp.py", line 82, in <module>
[rank2]: trainer.fit(model, train_dataloader)
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 560, in fit
[rank2]: call._call_and_handle_interrupt(
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 49, in _call_and_handle_interrupt
[rank2]: return trainer_fn(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 598, in _fit_impl
[rank2]: self._run(model, ckpt_path=ckpt_path)
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1011, in _run
[rank2]: results = self._run_stage()
[rank2]: ^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1055, in _run_stage
[rank2]: self.fit_loop.run()
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 216, in run
[rank2]: self.advance()
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 458, in advance
[rank2]: self.epoch_loop.run(self._data_fetcher)
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 152, in run
[rank2]: self.advance(data_fetcher)
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 348, in advance
[rank2]: batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 192, in run
[rank2]: self._optimizer_step(batch_idx, closure)
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 270, in _optimizer_step
[rank2]: call._call_lightning_module_hook(
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 177, in _call_lightning_module_hook
[rank2]: output = fn(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1366, in optimizer_step
[rank2]: optimizer.step(closure=optimizer_closure)
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py", line 154, in step
[rank2]: step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 239, in optimizer_step
[rank2]: return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/fsdp.py", line 157, in optimizer_step
[rank2]: return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 123, in optimizer_step
[rank2]: return optimizer.step(closure=closure, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/optim/optimizer.py", line 517, in wrapper
[rank2]: out = func(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/optim/optimizer.py", line 82, in _use_grad
[rank2]: ret = func(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/optim/adam.py", line 226, in step
[rank2]: loss = closure()
[rank2]: ^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 109, in _wrap_closure
[rank2]: closure_result = closure()
[rank2]: ^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 146, in __call__
[rank2]: self._result = self.closure(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
[rank2]: return func(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 131, in closure
[rank2]: step_output = self._step_fn()
[rank2]: ^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 319, in _training_step
[rank2]: training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 333, in _call_strategy_hook
[rank2]: output = fn(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
[rank2]: return self._forward_redirection(self.model, self.lightning_module, "training_step", *args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 641, in __call__
[rank2]: wrapper_output = wrapper_module(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 851, in forward
[rank2]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 634, in wrapped_forward
[rank2]: out = method(*_args, **_kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/user/Lightning/llamafsdp.py", line 29, in training_step
[rank2]: output = self(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/home/user/Lightning/llamafsdp.py", line 26, in forward
[rank2]: return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/transformers/utils/generic.py", line 918, in wrapper
[rank2]: output = func(self, *args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/transformers/models/llama/modeling_llama.py", line 474, in forward
[rank2]: logits = self.lm_head(hidden_states[:, slice_indices, :])
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/opt/AI/training-2.9.0/lib/python3.12/site-packages/torch/nn/modules/linear.py", line 134, in forward
[rank2]: return F.linear(input, self.weight, self.bias)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: RuntimeError: size mismatch, got input (96), mat (96x2048), vec (65667072)
```

Ranks 0, 1, and 3 print identical tracebacks (differing only in the rank prefix; rank 0 additionally passes through the subprocess launcher frames) and all end in the same error: `RuntimeError: size mismatch, got input (96), mat (96x2048), vec (65667072)`.
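One observation that might help triage: the `vec (65667072)` in the error is exactly a 4-way shard of the lm_head/embedding weight for this config (Llama-3.2-1B ties lm_head to the input embedding), so it looks like lm_head is seeing a still-sharded flat parameter at forward time. A quick check of the arithmetic (config values taken from the published Llama-3.2-1B config):

```python
# Llama-3.2-1B config: vocab_size=128256, hidden_size=2048
vocab_size, hidden_size, world_size = 128256, 2048, 4

full_lm_head = vocab_size * hidden_size  # 262668288 elements in lm_head.weight
print(full_lm_head // world_size)        # 65667072 -> matches "vec (65667072)"

# "input (96)" / "mat (96x2048)": 10 samples over 4 ranks -> 3 per rank (padded),
# times max_length=32 tokens = 96 rows of hidden_states
print(3 * 32)                            # 96 -> matches "mat (96x2048)"
```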
Environment
Python 3.12.12
torch 2.9.1
pytorch-lightning 2.5.6
transformers 4.57.1
More info
No response