-
Notifications
You must be signed in to change notification settings - Fork 0
Feature/quantisation #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| future_plans/ | ||
| __pycache__/ | ||
| *.pyc | ||
| **/__pycache__/ | ||
| *.pyc | ||
|
|
||
| docs/quantisation.md | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,9 @@ def __init__(self, model, tokenizer, sampler: Sampler | None = None): | |
| @torch.inference_mode() | ||
| def generate(self, prompt: str, max_new_tokens: int = 50, params: SamplingParams | None = None): | ||
| input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.model.config.device) | ||
| prompt_len = input_ids.shape[1] | ||
| past_key_values = None | ||
| prev_text = "" | ||
|
|
||
| for _ in range(max_new_tokens): | ||
| if past_key_values is None: | ||
|
|
@@ -25,5 +27,11 @@ def generate(self, prompt: str, max_new_tokens: int = 50, params: SamplingParams | |
|
|
||
| if next_token.item() == self.tokenizer.eos_token_id: | ||
| break | ||
|
|
||
| yield self.tokenizer.decode(next_token[0], skip_special_tokens=True) | ||
|
|
||
| # Decode all generated tokens so far and yield only the new text. | ||
| # This correctly handles SentencePiece space prefixes and multi-byte chars. | ||
| full_text = self.tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True) | ||
| new_text = full_text[len(prev_text):] | ||
| prev_text = full_text | ||
| if new_text: | ||
| yield new_text | ||
|
Comment on lines
+31
to
+37
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Avoid decoding the entire continuation on every token.
🤖 Prompt for AI Agents |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,9 @@ | ||
| """CLI entry point.""" | ||
| import os | ||
| import warnings | ||
| warnings.filterwarnings("ignore", message=".*_check_is_size.*", category=FutureWarning) | ||
| import argparse | ||
| import torch | ||
| from huggingface_hub import try_to_load_from_cache | ||
| from transformers import AutoTokenizer | ||
|
|
||
| from models.weight_loader import load_hf_model | ||
|
|
@@ -15,15 +16,6 @@ | |
| MODEL_ID = "Qwen/Qwen3-0.6B" | ||
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | ||
|
|
||
| # Resolve local cache path if model is already downloaded | ||
| # Passing a local dir path to AutoTokenizer prevents ALL network calls | ||
| _cached = try_to_load_from_cache(MODEL_ID, "config.json") | ||
| LOCAL_MODEL_PATH = os.path.dirname(_cached) if isinstance(_cached, str) else None | ||
|
|
||
| if LOCAL_MODEL_PATH: | ||
| os.environ["HF_HUB_OFFLINE"] = "1" | ||
| os.environ["TRANSFORMERS_OFFLINE"] = "1" | ||
|
|
||
|
|
||
| def parse_args(): | ||
| parser = argparse.ArgumentParser(description="vLLMini Chat") | ||
|
|
@@ -33,6 +25,7 @@ def parse_args(): | |
| parser.add_argument("--temperature", type=float, default=0.7, help="Sampling temperature") | ||
| parser.add_argument("--top-p", type=float, default=0.9, help="Nucleus sampling threshold") | ||
| parser.add_argument("--max-tokens", type=int, default=2048, help="Maximum new tokens to generate") | ||
| parser.add_argument("--quantize", "-q", action="store_true", default=False, help="Enable 4-bit NF4 quantization (requires bitsandbytes)") | ||
| return parser.parse_args() | ||
|
|
||
|
|
||
|
|
@@ -45,15 +38,19 @@ def strip_thinking(output: str) -> str: | |
|
|
||
| def main(): | ||
| args = parse_args() | ||
|
|
||
| # Don't force offline mode for model loading — the weight_loader | ||
| # handles local-first-then-download fallback on its own. | ||
| os.environ.pop("HF_HUB_OFFLINE", None) | ||
| os.environ.pop("TRANSFORMERS_OFFLINE", None) | ||
|
|
||
| model, config = load_hf_model(args.model_id, device=args.device, quantize=args.quantize) | ||
|
|
||
| tokenizer = AutoTokenizer.from_pretrained(args.model_id) | ||
|
Comment on lines
+42
to
+49
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don’t clear offline mode here. These Suggested fix- os.environ.pop("HF_HUB_OFFLINE", None)
- os.environ.pop("TRANSFORMERS_OFFLINE", None)
+ local_only = (
+ os.environ.get("HF_HUB_OFFLINE") == "1"
+ or os.environ.get("TRANSFORMERS_OFFLINE") == "1"
+ )
model, config = load_hf_model(args.model_id, device=args.device, quantize=args.quantize)
- tokenizer = AutoTokenizer.from_pretrained(args.model_id)
+ tokenizer = AutoTokenizer.from_pretrained(args.model_id, local_files_only=local_only)🧰 Tools🪛 Ruff (0.15.12)[warning] 47-47: Unpacked variable Prefix it with an underscore or any other dummy variable pattern (RUF059) 🤖 Prompt for AI Agents |
||
| # chat = [{"role": "user", "content": "Write a short story about a robot."}] | ||
| # prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) | ||
|
|
||
| model, config = load_hf_model(args.model_id, device=args.device) | ||
|
|
||
| # Use local model path if available (from user's caching logic) | ||
| tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH or args.model_id) | ||
| chat = [{"role": "user", "content": "Write a short story about a robot."}] | ||
| prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) | ||
|
|
||
| # prompt = "Write a very long story about a robot." | ||
| prompt = "Write a very long story about a robot." | ||
|
|
||
| params = SamplingParams(temperature=args.temperature, top_p=args.top_p) | ||
| sampler = Sampler() | ||
|
|
@@ -117,6 +114,11 @@ def main(): | |
| if remainder: | ||
| print(remainder, end="", flush=True) | ||
| parts.append(remainder) | ||
| elif not indicator_shown and len(buffer) > 20: | ||
| # Model doesn't use <think> tags — flush buffer and stream normally | ||
| thinking_done = True | ||
| print(buffer, end="", flush=True) | ||
| parts.append(buffer) | ||
| # Otherwise keep accumulating silently | ||
| else: | ||
| # Either HIDE_THINKING is False, or we're past </think> | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor risk: verify
docs/quantisation.mdshould be ignored.If
docs/quantisation.mdis supposed to be part of the repo’s documentation, ignoring it can lead to accidental omission of doc updates (and if it’s not currently tracked, it won’t be committed going forward).Suggested check / potential fix
-docs/quantisation.md📝 Committable suggestion
🤖 Prompt for AI Agents