Description
Environment:

grain==0.2.12
array_record==0.8.1
protobuf==6.32.0
WSL2 / Python 3.12.11

Running the DataLoader example below with multiprocessing workers completes normally, but after the script exits, Python's resource_tracker emits UserWarnings about leaked shared_memory objects (full output below).
"""Example of Grain DataLoader with CPU-heavy RandomAccessDataSource."""
import time
import numpy as np
import grain
class CPUHeavyDataSource:
"""Simulates a CPU-intensive data source."""
def __init__(self, num_samples: int = 1000):
self.num_samples = num_samples
def __len__(self) -> int:
return self.num_samples
def __getitem__(self, index: int) -> dict:
result = 0
for i in range(2):
result += np.sin(np.arange(200_000)+i)
# Return some data based on the computation
return {
"index": index,
"data": np.random.RandomState(index).randn(128),
"computed": result,
}
class LightTransform(grain.transforms.Map):
"""Light transformation after the heavy data loading."""
def map(self, x: dict) -> dict:
x["data"] = x["data"] * 2.0 # Simple operation
return x
def main():
# Configuration for CPU-heavy computation
from multiprocessing import cpu_count
num_cores = cpu_count()
print(f"System has {num_cores} CPU cores.")
num_cores -= 4
print(f"We will use {num_cores} CPU cores.")
batch_size = 128
num_epochs = 1
# Create data source
data_source = CPUHeavyDataSource(num_samples=batch_size*20) # 20 batches
# Configure read options for CPU-heavy workload
read_options = grain.ReadOptions(
num_threads=0, # No threading within workers (avoids GIL)
prefetch_buffer_size=3, # Minimal buffer since CPU is bottleneck
)
sampler = grain.samplers.IndexSampler(len(data_source), num_epochs=num_epochs)
# Create DataLoader with load API
loader = grain.DataLoader(
data_source=data_source,
sampler=sampler,
operations=[
LightTransform(),
grain.transforms.Batch(batch_size=batch_size, drop_remainder=True),
],
worker_count=num_cores, # Use most cores, leave some for system
read_options=read_options,
)
print(f"Starting data loading with {num_cores} workers, batch_size={batch_size}...")
start_time = time.time()
batch_count = 0
for batch in loader:
batch_mean = np.mean(batch["data"])
batch_count += 1
if batch_count % 10 == 0:
elapsed = time.time() - start_time
samples_processed = batch_count * batch_size
print(f"Processed {batch_count} batches ({samples_processed} samples) "
f"in {elapsed:.1f}s ({samples_processed / elapsed:.1f} samples/sec)")
total_time = time.time() - start_time
total_samples = batch_count * batch_size
print(f"\nCompleted: {total_samples} samples in {total_time:.1f}s "
f"({total_samples / total_time:.1f} samples/sec)")
if __name__ == "__main__":
main()Output:
System has 24 CPU cores.
We will use 20 CPU cores.
Starting data loading with 20 workers, batch_size=128...
Processed 10 batches (1280 samples) in 2.6s (494.8 samples/sec)
Processed 20 batches (2560 samples) in 2.7s (957.2 samples/sec)
Completed: 2560 samples in 2.9s (890.1 samples/sec)
admin@admin-pc:/mnt/c/Users/admin/tmp$ /usr/lib/python3.12/multiprocessing/resource_tracker.py:279: UserWarning: resource_tracker: There appear to be 3 leaked shared_memory objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
/usr/lib/python3.12/multiprocessing/resource_tracker.py:292: UserWarning: resource_tracker: '/psm_2d261897': [Errno 2] No such file or directory: '/psm_2d261897'
warnings.warn('resource_tracker: %r: %s' % (name, e))
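For context, the first warning comes from CPython's multiprocessing.resource_tracker, which reports shared_memory segments that are still registered when the interpreter shuts down. A minimal standalone sketch (plain stdlib, no Grain involved) that triggers the same UserWarning on Python 3.12, assuming the segment is simply never unlinked:

    # Minimal sketch of the resource_tracker warning (not Grain code).
    from multiprocessing import shared_memory

    shm = shared_memory.SharedMemory(create=True, size=1024)
    shm.close()
    # Intentionally omitting shm.unlink(); at interpreter exit, the tracker
    # cleans the segment up and prints:
    #   UserWarning: resource_tracker: There appear to be 1 leaked
    #   shared_memory objects to clean up at shutdown

The second warning ("[Errno 2] No such file or directory") suggests the opposite case: a segment still registered with the tracker whose backing file is already gone by the time the tracker tries to clean it up, e.g. because another process unlinked it first.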