
Commit 2d9a17a

Refactor bucketing, disable torch.compile by default
1 parent: a53d6e1

File tree

3 files changed · +12 -7 lines changed

src/instructlab/training/main_ds.py
src/instructlab/training/multipack_sampler.py
src/instructlab/training/token_dataset.py


src/instructlab/training/main_ds.py

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ def setup_model(
     else:
         model = AutoModelForCausalLM.from_pretrained(**base_model_args)
 
-    if is_torch_hpu_available():
+    if is_torch_hpu_available() and os.getenv("HPU_ENABLE_TORCH_COMPILE", False):
         torch._dynamo.config.cache_size_limit = int(1e4)
         torch._dynamo.config.accumulated_cache_size_limit = int(2e4)
         model = torch.compile(model, backend="hpu_backend", dynamic=False)
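
With this change, torch.compile on HPU becomes opt-in through the HPU_ENABLE_TORCH_COMPILE environment variable. One detail worth noting, which follows from standard os.getenv semantics rather than anything specific to this commit: the variable's value comes back as a raw string, so any non-empty setting, including "0", enables the compile path; only leaving the variable unset keeps it disabled. A minimal illustration:

    import os

    # Any non-empty string is truthy, so even HPU_ENABLE_TORCH_COMPILE=0
    # would enable compilation; only an unset variable disables it.
    os.environ["HPU_ENABLE_TORCH_COMPILE"] = "1"
    print(bool(os.getenv("HPU_ENABLE_TORCH_COMPILE", False)))  # True

    del os.environ["HPU_ENABLE_TORCH_COMPILE"]
    print(bool(os.getenv("HPU_ENABLE_TORCH_COMPILE", False)))  # False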

src/instructlab/training/multipack_sampler.py

Lines changed: 6 additions & 6 deletions
@@ -70,9 +70,14 @@ def get_effective_samples_per_minibatch(num_tokens_per_gpu):
 
     The function creates a sampler using the MultipackDistributedBatchSampler class, generates batches using the sampler, and then returns the ratio of the dataset size to the number of batches.
     """
+    lengths = dataset.get_lengths()
+    if is_torch_hpu_available():
+        bucket_v = np.vectorize(bucket)
+        lengths = bucket_v(lengths)
+
     sampler = MultipackDistributedBatchSampler(
         batch_max_length=num_tokens_per_gpu,
-        lengths=dataset.get_lengths(),
+        lengths=lengths,
         num_replicas=torch.distributed.get_world_size(),
         rank=torch.distributed.get_rank(),
         seed=seed,

@@ -397,11 +402,6 @@ def generate_batches(self, set_stats=False):
         )
 
         lengths = self.lengths[indices]
-
-        if is_torch_hpu_available():
-            bucket_v = np.vectorize(bucket)
-            lengths = bucket_v(lengths)
-
         lengths_cumsum = np.cumsum(lengths)
 
         batches, total_used, total_slots = allocate(
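
The bucket helper is imported from instructlab.training.hpu_utils, whose implementation is not part of this diff. As a rough illustration of the kind of length bucketing used to limit the set of distinct tensor shapes the HPU graph compiler must handle, a hypothetical stand-in might round each length up to the next power of two:

    import numpy as np

    def bucket(length: int) -> int:
        # Hypothetical stand-in for hpu_utils.bucket: round a sequence
        # length up to the next power of two, so the compiler sees a
        # small fixed set of shapes instead of arbitrary lengths.
        return 1 << max(0, int(length) - 1).bit_length()

    bucket_v = np.vectorize(bucket)
    print(bucket_v(np.array([100, 300, 1000])))  # [ 128  512 1024]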

src/instructlab/training/token_dataset.py

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,7 @@
 from instructlab.training.multipack_sampler import MultipackDistributedBatchSampler
 from instructlab.training.utils import log_rank_0, make_collate_fn
 
+from instructlab.training.hpu_utils import is_torch_hpu_available, bucket
 
 class TokenDataset(Dataset):
     def __init__(self, data_path):

@@ -109,6 +110,10 @@ def setup_dataloader(
 
     lengths = dataset.get_lengths()
     if sampler == "multipack":
+        if is_torch_hpu_available():
+            bucket_v = np.vectorize(bucket)
+            lengths = bucket_v(lengths)
+
         sampler = MultipackDistributedBatchSampler(
             batch_max_length=packing_max_batch_len,
             lengths=lengths,
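
Taken together, the refactor moves the HPU bucketing out of MultipackDistributedBatchSampler.generate_batches and into the two call sites (get_effective_samples_per_minibatch and setup_dataloader), so lengths are bucketed once up front instead of on every batch-generation pass, and the sampler packs against the bucketed sizes. A small sketch of why packing against bucketed lengths matters, reusing the hypothetical bucket_v above:

    import numpy as np

    raw = np.array([100, 300, 1000])     # actual token counts per sample
    padded = np.array([128, 512, 1024])  # bucket_v(raw): sizes after padding

    # generate_batches packs against np.cumsum(lengths); budgeting with the
    # padded sizes presumably keeps each packed batch under batch_max_length
    # even after samples are padded up to their bucket boundaries.
    print(raw.sum(), padded.sum())  # 1400 vs. 1664 tokens budgeted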
