diff --git a/catgrad-llm/scripts/compare/compare.sh b/catgrad-llm/scripts/compare/compare.sh
index e3cd6b6..808d6f0 100755
--- a/catgrad-llm/scripts/compare/compare.sh
+++ b/catgrad-llm/scripts/compare/compare.sh
@@ -27,7 +27,26 @@ OUTPUT_DIR="$DIR/outputs"
 mkdir -p "$OUTPUT_DIR"
 rm -rf "$OUTPUT_DIR"/*
 
-echo "Generating outputs for ${#MODELS[@]} models..."
+MAXLEN="${CATGRAD_COMPARE_MAXLEN:-40}"
+
+echo "Generating outputs of ${MAXLEN} tokens for ${#MODELS[@]} models..."
+
+if [[ "${CATGRAD_COMPARE_HF_RUN:-}" ]]; then
+    REFERENCE_DIR=$DIR/expected/$MAXLEN
+    mkdir -p $REFERENCE_DIR
+
+    for model in "${MODELS[@]}"; do
+        # Replace slashes with dashes for the filename
+        filename="${model//\//-}"
+
+        echo "Running HF Transformers for $model -> $REFERENCE_DIR/$filename"
+
+        uv run catgrad-llm/scripts/llm.py -m "$model" -p 'Category theory is' -s $MAXLEN -r > "$REFERENCE_DIR/$filename" 2>/dev/null &
+    done
+
+    wait
+fi
+
 
 for model in "${MODELS[@]}"; do
     # Replace slashes with dashes for the filename
@@ -37,7 +56,7 @@ for model in "${MODELS[@]}"; do
 
     TYPECHECK="-t"
 
-    ./target/release/examples/llama -m "$model" -p 'Category theory is' -s 40 --raw -k $TYPECHECK > "$OUTPUT_DIR/$filename" 2>/dev/null &
+    ./target/release/examples/llama -m "$model" -p 'Category theory is' -s $MAXLEN --raw -k $TYPECHECK > "$OUTPUT_DIR/$filename" 2>/dev/null &
 
     [[ -z "${GITHUB_ACTIONS:-}" ]] || wait
 done
diff --git a/catgrad-llm/scripts/llm.py b/catgrad-llm/scripts/llm.py
index daafef3..f9a3702 100644
--- a/catgrad-llm/scripts/llm.py
+++ b/catgrad-llm/scripts/llm.py
@@ -23,7 +23,7 @@
 parser.add_argument("--revision", type=str, default="main")
 parser.add_argument("-p", "--prompt", type=str, default="Category theory is")
 parser.add_argument("-s", "--seq-len", type=int, default=10)
-parser.add_argument("-r", "--raw-prompt", action="store_true")
+parser.add_argument("-r", "--raw", action="store_true")
 parser.add_argument("-t", "--thinking", action="store_true")
 parser.add_argument("--cache", dest="use_cache", action=argparse.BooleanOptionalAction, default=True)
 parser.add_argument("-d", "--dtype", type=str, default="float32")
@@ -42,7 +42,7 @@
 print(f"Loaded model {args.model}, dtype:{model.dtype} on device {model.device}", file=sys.stderr)
 
 prompt = args.prompt
-if not args.raw_prompt and tokenizer.chat_template is not None:
+if not args.raw and tokenizer.chat_template is not None:
     chat = [
         {"role": "user", "content": prompt},
     ]
@@ -60,6 +60,6 @@
 inputs = tokenizer(prompt, return_token_type_ids=False, return_tensors="pt")
 
 logits = model.generate(**inputs, max_new_tokens=args.seq_len, do_sample=False, use_cache=args.use_cache)
 
-output = tokenizer.decode(logits[0])
+output = tokenizer.decode(logits[0], skip_special_tokens=True)
 print(output)
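
A usage sketch of the patched script. The variable names CATGRAD_COMPARE_HF_RUN and CATGRAD_COMPARE_MAXLEN and the 40-token default come from the diff above; invoking from the repository root is an assumption.

    # Assumed invocation from the repo root: a non-empty CATGRAD_COMPARE_HF_RUN
    # enables the new HF Transformers reference pass, writing references into
    # expected/$MAXLEN; CATGRAD_COMPARE_MAXLEN overrides the 40-token default
    # for both the reference run and the catgrad llama comparison run.
    CATGRAD_COMPARE_HF_RUN=1 CATGRAD_COMPARE_MAXLEN=64 ./catgrad-llm/scripts/compare/compare.sh

    # Default run: 40 tokens, no HF reference regeneration.
    ./catgrad-llm/scripts/compare/compare.sh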