-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_checkpoint.py
More file actions
134 lines (110 loc) · 4.94 KB
/
test_checkpoint.py
File metadata and controls
134 lines (110 loc) · 4.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
from model.model import DeepSeekConfig, DeepSeekModel
from transformers import AutoTokenizer
import os
import argparse
from pathlib import Path
# Test prompts
# A small fixed set of open-ended prompts spanning different domains
# (science, fiction, cooking, astronomy, AI) so generation quality can be
# eyeballed across topics from a single run.
TEST_PROMPTS = [
    "Explain the concept of quantum entanglement in simple terms:",
    "Write a short story about a time traveler who:",
    "Here's a recipe for a delicious vegetarian dish:",
    "The most fascinating discovery in astronomy was:",
    "The future of artificial intelligence will likely involve:"
]
def test_generation_from_checkpoint(checkpoint_path, device='cuda', debug=True):
    """Load a model checkpoint and run generation over TEST_PROMPTS.

    Args:
        checkpoint_path: Path to a ``.pt`` checkpoint containing at least a
            ``'model_state_dict'`` entry (``'step'`` and ``'loss'`` optional).
        device: Torch device string to load and run on (e.g. ``'cuda'``, ``'cpu'``).
        debug: When True, print tokenizer/checkpoint metadata and tensor shapes.

    Errors are caught and printed rather than raised so one bad prompt or a
    broken checkpoint does not abort the whole test run.
    """
    print(f"\nTesting checkpoint: {checkpoint_path}")
    model = None  # sentinel so cleanup in `finally` is unambiguous
    try:
        # Initialize model config.
        # NOTE(review): these hyperparameters are hard-coded and must match the
        # architecture the checkpoint was trained with — confirm against training config.
        config = DeepSeekConfig()
        config.hidden_size = 384
        config.intermediate_size = 1024
        config.num_attention_heads = 6
        config.num_key_value_heads = 2
        config.num_hidden_layers = 12
        config.max_position_embeddings = 512

        # Load tokenizer
        print("\nLoading tokenizer...")
        tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
        config.vocab_size = tokenizer.vocab_size
        if debug:
            print(f"Vocabulary size: {tokenizer.vocab_size}")
            print(f"Padding token ID: {tokenizer.pad_token_id}")
            print(f"EOS token ID: {tokenizer.eos_token_id}")

        # Initialize model
        print("\nInitializing model...")
        model = DeepSeekModel(config).to(device)

        # Load checkpoint.
        # weights_only=False is explicit: PyTorch >= 2.6 changed the default to
        # True, which rejects checkpoints containing arbitrary pickled objects.
        # Only load checkpoints from trusted sources with this setting.
        print("\nLoading checkpoint...")
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])
        step = checkpoint.get('step', 'unknown')
        loss = checkpoint.get('loss', 'unknown')
        if debug:
            print(f"Checkpoint step: {step}")
            print(f"Checkpoint loss: {loss}")

        # Set model to eval mode (disables dropout etc. for generation)
        model.eval()
        print("\nStarting generation tests...")
        print("=" * 50)
        for i, prompt in enumerate(TEST_PROMPTS, 1):
            print(f"\nTesting prompt {i}: {prompt}")
            try:
                # Tokenize with shape logging
                input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
                if debug:
                    print(f"Input shape: {input_ids.shape}")
                # Generate with error catching
                with torch.no_grad():
                    try:
                        outputs = model.generate(
                            input_ids,
                            max_length=100,
                            temperature=0.8,
                            top_k=50
                        )
                        if debug:
                            print(f"Output shape: {outputs.shape}")
                        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                        print(f"\nGenerated text:\n{generated_text}")
                    except RuntimeError as e:
                        print(f"Generation error: {str(e)}")
                        print(f"Input tensor shape: {input_ids.shape}")
                        continue
            except Exception as e:
                print(f"Error processing prompt {i}: {str(e)}")
                continue
            print("-" * 50)
    except Exception as e:
        print(f"\nError during testing: {str(e)}")
    finally:
        # Cleanup: drop the model reference and, only when CUDA is present,
        # release cached GPU memory (avoids touching CUDA on CPU-only runs).
        if model is not None:
            del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
def main():
    """Parse CLI arguments and run generation tests against a checkpoint.

    When ``--checkpoint`` is omitted, falls back to the newest
    ``model_step_*.pt`` file under ``./checkpoints/`` (highest step number).
    """
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    parser = argparse.ArgumentParser(description='Test text generation from a checkpoint')
    parser.add_argument('--checkpoint', type=str, help='Path to checkpoint file')
    parser.add_argument('--device', type=str, default=default_device,
                        help='Device to run generation on')
    parser.add_argument('--debug', action='store_true', help='Enable debug output')
    args = parser.parse_args()

    # No explicit checkpoint: pick the latest one, bailing out early on any miss.
    if not args.checkpoint:
        ckpt_dir = Path('checkpoints')
        if not ckpt_dir.exists():
            print("Checkpoints directory not found")
            return
        # Sort ascending by the numeric step embedded in "model_step_<N>.pt";
        # the last entry is the newest checkpoint.
        candidates = sorted(ckpt_dir.glob('model_step_*.pt'),
                            key=lambda p: int(p.stem.split('_')[2]))
        if not candidates:
            print("No checkpoints found in ./checkpoints/")
            return
        args.checkpoint = str(candidates[-1])
        print(f"Using latest checkpoint: {args.checkpoint}")

    test_generation_from_checkpoint(args.checkpoint, args.device, args.debug)
# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
    main()