3 changes: 2 additions & 1 deletion .gitignore
@@ -214,6 +214,7 @@ __marimo__/

# personal files
*technical_architecture.md
*PLAN.md
*test_outputs/
*AGENTS.md
*personal_experimentation/
*uv.lock
10 changes: 10 additions & 0 deletions config/lm_eval_test_config.yaml
@@ -0,0 +1,10 @@
model: hf
Collaborator: Where do you use it?

model_args: pretrained=gpt2,dtype=float32
tasks:
- hellaswag
batch_size: 2
num_fewshot: 0
output_dir: test_outputs
limit: 3
device: cpu
seed: 42
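
These keys line up with the arguments of lm-evaluation-harness. A minimal sketch of how the config could be consumed, assuming the keys are forwarded to `lm_eval.simple_evaluate`; how `output_dir` and `seed` are wired up is not shown here and is an assumption, not part of this PR:

# Hedged sketch, assuming lm-evaluation-harness (lm_eval >= 0.4) is installed and
# that the YAML keys map directly onto simple_evaluate(); output_dir/seed handling
# is left to the caller and is not shown.
import yaml
import lm_eval

with open("config/lm_eval_test_config.yaml") as f:
    cfg = yaml.safe_load(f)

results = lm_eval.simple_evaluate(
    model=cfg["model"],            # "hf" -> HuggingFace backend
    model_args=cfg["model_args"],  # "pretrained=gpt2,dtype=float32"
    tasks=cfg["tasks"],            # ["hellaswag"]
    num_fewshot=cfg["num_fewshot"],
    batch_size=cfg["batch_size"],
    device=cfg["device"],
    limit=cfg["limit"],            # 3 documents per task, so this stays a quick smoke test
)
print(results["results"]["hellaswag"])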
83 changes: 81 additions & 2 deletions eval_converters/common/utils.py
@@ -1,4 +1,10 @@
from schema.eval_types import Family, HfSplit
from schema.eval_types import (
    BitPrecision,
    Family,
    HfSplit,
    QuantizationMethod,
    QuantizationType)
from transformers import AutoConfig

def detect_family(model_name: str) -> Family:
"""Return the Family enum if any of its values is a substring of model_name."""
@@ -25,4 +31,77 @@ def detect_hf_split(split_str: str) -> HfSplit:
elif "train" in s:
return HfSplit.train
else:
return HfSplit.validation
return HfSplit.validation

def infer_quantization_from_model_name(model_name_or_path: str) -> tuple[BitPrecision, QuantizationMethod, QuantizationType]:
    pass

def infer_quantization_from_model_config(model_name_or_path: str) -> tuple[BitPrecision, QuantizationMethod, QuantizationType]:
    pass

def infer_quantization(model_name_or_path: str) -> tuple[BitPrecision, QuantizationMethod, QuantizationType]:
    try:
        cfg = AutoConfig.from_pretrained(model_name_or_path)
    except Exception:
        return BitPrecision.none, QuantizationMethod.none, QuantizationType.none

    qcfg = getattr(cfg, "quantization_config", None)
    if not qcfg:
        return BitPrecision.none, QuantizationMethod.none, QuantizationType.none

    # guard against configs that expose none of the known bit-width keys
    bits_raw = qcfg.get("bits") or qcfg.get("weight_bits") or qcfg.get("q_bits")
    bits = int(bits_raw) if bits_raw is not None else 0

    if bits == 8:
        precision = BitPrecision.int8
    elif bits == 4:
        precision = BitPrecision.int4
    elif bits == 16:
        precision = BitPrecision.float16
    elif bits == 32:
        precision = BitPrecision.float32
    else:
        precision = BitPrecision.none

    method_key = str(qcfg.get("quant_method") or "").lower()

    method_map = {
        "gptq": QuantizationMethod.gptq,
        "awq": QuantizationMethod.awq,
    }

    type_map = {
        "gptq": QuantizationType.static,
        "awq": QuantizationType.static,
        "bitsandbytes": QuantizationType.dynamic,
        "quanto": QuantizationType.static,
        "hqq": QuantizationType.static,
        "torchao": QuantizationType.static,
    }

    qmethod = method_map.get(method_key, QuantizationMethod.none)
    qtype = type_map.get(method_key, QuantizationType.none)
    return precision, qmethod, qtype

def extract_context_window_from_config(model):
    try:
        config = AutoConfig.from_pretrained(model)

        priority_fields = [
            "max_position_embeddings",
            "n_positions",
            "seq_len",
            "seq_length",
            "n_ctx",
            "sliding_window"
        ]

        context_window = next((getattr(config, f) for f in priority_fields if hasattr(config, f)), None)
        if context_window is None:
            context_window = 1

    except Exception as e:
        print(f"Error getting context window: {e}")
        context_window = 1

    finally:
        return context_window
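
To illustrate the intent of the new helpers, a small sketch; the model ids below are only examples, and the config download happens inside the helpers via AutoConfig.from_pretrained:

# Hedged sketch: calling the new helpers with example model ids.
# "TheBloke/Llama-2-7B-GPTQ" and "gpt2" are illustrative; any HuggingFace id works.
from eval_converters.common.utils import (
    infer_quantization,
    extract_context_window_from_config,
)

precision, method, qtype = infer_quantization("TheBloke/Llama-2-7B-GPTQ")
print(precision, method, qtype)  # expected: int4 / gptq / static for a GPTQ checkpoint

print(extract_context_window_from_config("gpt2"))  # gpt2's config reports 1024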
115 changes: 17 additions & 98 deletions eval_converters/helm/adapter.py
@@ -17,57 +17,14 @@
from schema import SCHEMA_VERSION

from eval_converters.common.adapter import BaseEvaluationAdapter, AdapterMetadata, SupportedLibrary
from eval_converters.common.utils import detect_family, detect_hf_split
from eval_converters.common.utils import detect_family, detect_hf_split, infer_quantization, extract_context_window_from_config
from .utils import detect_prompt_class, get_adapter_class_from_method_string

from transformers import AutoConfig

# run this just once in your process to initialize the registry
register_builtin_configs_from_helm_package()

def infer_quantization(model_name_or_path: str):
    """
    Returns (BitPrecision, Method) enums for the given HF model.
    """
    try:
        cfg = AutoConfig.from_pretrained(model_name_or_path)
    except Exception as e:
        raise ValueError(
            f"Failed to load model config for {model_name_or_path}: {e} \n"
            "This may happen if you are using a HELM model name instead of HuggingFace model name in the adapter_spec.model field."
            "For example, HELM uses 'meta/llama-3.1-8b-instruct' while HuggingFace uses meta-llama/llama-3.1-8b-instruct' \n"
            "Please verify the model name and try again."
        )
    qcfg = getattr(cfg, "quantization_config", None)

    if qcfg is None:
        return BitPrecision.none, Method.None_

    bits = int(qcfg.get("bits") or qcfg.get("weight_bits") or qcfg.get("q_bits"))

    if bits == 8:
        precision = BitPrecision.int8
    elif bits == 4:
        precision = BitPrecision.int4
    elif bits == 16:
        precision = BitPrecision.float16
    elif bits == 32:
        precision = BitPrecision.float32
    else:
        precision = BitPrecision.none

    method_key = qcfg.get("quant_method") or ""
    method_map = {
        "gptq": Method.static,
        "awq": Method.static,
        "bitsandbytes": Method.dynamic,
        "quanto": Method.static,
        "hqq": Method.static,
        "torchao": Method.static,
    }

    method = method_map.get(method_key, Method.None_)
    return precision, method

class HELMAdapter(BaseEvaluationAdapter):
"""
@@ -148,33 +105,14 @@ def transform_from_directory(self, dir_path):
)

        # 1.2. Configuration
        # HELM does not provide context window size, try loading it from model config, else set to 1
        try:
            # try getting context window from model deployment
            deployment = get_model_deployment(adapter_spec.model_deployment)
            if deployment and deployment.max_sequence_length is not None:
                context_window = deployment.max_sequence_length

            # if not available, try loading it from model config
            else:
                config = AutoConfig.from_pretrained(adapter_spec.model)

                priority_fields = [
                    "max_position_embeddings",
                    "n_positions",
                    "seq_len",
                    "seq_length",
                    "n_ctx",
                    "sliding_window"
                ]

                context_window = next((getattr(config, f) for f in priority_fields if hasattr(config, f)), None)
                if context_window is None:
                    context_window = 1

        except Exception as e:
            print(f"Error getting context window: {e}")
            context_window = 1
        # HELM does not provide context window size, try loading it from model deployment, else set to 1
        deployment = get_model_deployment(adapter_spec.model_deployment)
        if deployment and deployment.max_sequence_length is not None:
            context_window = deployment.max_sequence_length

        # if not available, try loading it from model config, else set to 1
        else:
            context_window = extract_context_window_from_config(adapter_spec.model)

        configuration = Configuration(
            context_window=context_window,
@@ -336,33 +274,14 @@ def _transform_single(self, raw_data, base_dir=None):
)

        # 1.2. Configuration
        # HELM does not provide context window size, try loading it from model config, else set to 1
        try:
            # try getting context window from model deployment
            deployment = get_model_deployment(adapter_spec.model_deployment)
            if deployment and deployment.max_sequence_length is not None:
                context_window = deployment.max_sequence_length

            # if not available, try loading it from model config
            else:
                config = AutoConfig.from_pretrained(adapter_spec.model)

                priority_fields = [
                    "max_position_embeddings",
                    "n_positions",
                    "seq_len",
                    "seq_length",
                    "n_ctx",
                    "sliding_window"
                ]

                context_window = next((getattr(config, f) for f in priority_fields if hasattr(config, f)), None)
                if context_window is None:
                    context_window = 1

        except Exception as e:
            print(f"Error getting context window: {e}")
            context_window = 1
        # HELM does not provide context window size, try loading it from model deployment
        deployment = get_model_deployment(adapter_spec.model_deployment)
        if deployment and deployment.max_sequence_length is not None:
            context_window = deployment.max_sequence_length

        # if not available, try loading it from model config, else set to 1
        else:
            context_window = extract_context_window_from_config(adapter_spec.model)

        configuration = Configuration(
            context_window=context_window,
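
The resulting lookup order in both adapter methods is: HELM model deployment first, then the HuggingFace config via the shared helper, then a default of 1. A standalone sketch of that logic; the HELM import path is my assumption and may differ between versions, and the adapter itself already registers the built-in configs once at import time:

# Hedged sketch of the context-window fallback used in both adapter methods.
# The helm import path below is an assumption; adjust it to your installed version.
from helm.benchmark.model_deployment_registry import get_model_deployment
from eval_converters.common.utils import extract_context_window_from_config

def resolve_context_window(model_deployment: str, model: str) -> int:
    """Deployment limit first, then HF config (which itself falls back to 1)."""
    deployment = get_model_deployment(model_deployment)
    if deployment and deployment.max_sequence_length is not None:
        return deployment.max_sequence_length
    return extract_context_window_from_config(model)

# e.g. resolve_context_window(adapter_spec.model_deployment, adapter_spec.model)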
2 changes: 1 addition & 1 deletion eval_converters/helm/utils.py
@@ -59,4 +59,4 @@ def get_adapter_class_from_method_string(method_str: str) -> type[Adapter]:
        if key in method_str:
            return mapping[key]

    raise ValueError(f"Unknown adapter method string: {method_str}")
    raise ValueError(f"Unknown adapter method string: {method_str}")