Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion ais_bench/benchmark/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ def fill_model_path_if_datasets_need(model_cfg, dataset_cfg):
data_type = get_config_type(dataset_cfg.get("type"))
if data_type in DATASETS_NEED_MODELS:
model_path = model_cfg.get("path")
trust_remote_code = model_cfg.get("trust_remote_code", False)
if not model_path:
raise AISBenchConfigError(
UTILS_CODES.SYNTHETIC_DS_MISS_REQUIRED_PARAM,
"[path] in model config is required for synthetic(tokenid) and sharegpt dataset."
)
dataset_cfg.update({"model_path": model_path})
dataset_cfg.update({"model_path": model_path, "trust_remote_code": trust_remote_code})


def fill_test_range_use_num_prompts(num_prompts: int, dataset_cfg: dict):
if not num_prompts:
Expand Down
3 changes: 2 additions & 1 deletion ais_bench/benchmark/datasets/mooncake_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ def load(
fixed_schedule_auto_offset=False,
fixed_schedule_start_offset=0,
fixed_schedule_end_offset=-1,
trust_remote_code=False,
):
"""
Load Mooncake trace dataset
Expand Down Expand Up @@ -507,7 +508,7 @@ def load(

# 2. Load tokenizer
self.logger.info(f"Loading tokenizer from: {model_path}")
tokenizer = load_tokenizer(model_path)
tokenizer = load_tokenizer(model_path, trust_remote_code=trust_remote_code)
self.logger.info(f"Tokenizer loaded successfully, vocab_size: {tokenizer.vocab_size}")

# 3. Load and tokenize corpus
Expand Down
3 changes: 2 additions & 1 deletion ais_bench/benchmark/datasets/sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ class ShareGPTDataset(BaseDataset):
@staticmethod
def load(path, disable_shuffle, **kwargs):
tokenizer_path = kwargs.get("model_path", None)
tokenizer = load_tokenizer(tokenizer_path)
trust_remote_code = kwargs.get("trust_remote_code", False)
tokenizer = load_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code)
path = get_data_path(path, local_mode=True)
with open(path) as f:
dataset = json.load(f)
Expand Down
4 changes: 2 additions & 2 deletions ais_bench/benchmark/datasets/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def load(self, config, **kwargs):
self._check_synthetic_config(config)
request_count = config.get("RequestCount")
config_type = config.get("Type").lower()
trust_remote_code = config.get("TrustRemoteCode")
trust_remote_code = config.get("TrustRemoteCode") or kwargs.get("trust_remote_code", False) # Any place set True will take effect
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The expression `config.get("TrustRemoteCode") or kwargs.get("trust_remote_code", False)` might lead to unexpected behavior if `TrustRemoteCode` is explicitly set to `False` in the config but `trust_remote_code` is `True` in kwargs. While the comment says "Any place set True will take effect", using `or` will ignore an explicit `False` from the first source. If you want to prioritize one source, or ensure that an explicit boolean is respected, a more robust approach is to check for `None` specifically.

if config_type == "string":
string_config = config.get("StringConfig")
input_method = string_config["Input"]["Method"]
Expand Down Expand Up @@ -301,7 +301,7 @@ def load(self, config, **kwargs):
tokenizer_file_path = self.find_first_file_path(model_path_value, "tokenizer_config.json")

tokenizer_model = AutoTokenizer.from_pretrained(
os.path.dirname(tokenizer_file_path),
os.path.dirname(tokenizer_file_path),
trust_remote_code=trust_remote_code
)

Expand Down
4 changes: 2 additions & 2 deletions ais_bench/benchmark/summarizers/default_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def _calc_perf_data(
model_cfg: Model configuration
perf_datas: Raw performance data
"""
tokenizer = AISTokenizer(model_cfg.get("path"))
tokenizer = AISTokenizer(model_cfg.get("path"), model_cfg.get("trust_remote_code", False))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

In `_calc_perf_data`, `model_cfg.get("path")` is used to initialize `AISTokenizer`. If the path is missing or `None`, `AISTokenizer` (which calls `os.path.exists`) will raise a `TypeError`. Although `model_path` is usually validated earlier in the CLI flow, it is safer to check that the path exists here, or to handle a potential `None` value, to avoid process crashes during summarization.

conn = init_db(db_file_path)
all_numpy_data = load_all_numpy_from_db(conn)

Expand Down Expand Up @@ -256,7 +256,7 @@ def _load_details_perf_data(self, model_cfg: dict, dataset_group: list):
details_perf_datas = defaultdict(list)

# check tokenizer
load_tokenizer(tokenizer_path=model_cfg.get("path"))
load_tokenizer(tokenizer_path=model_cfg.get("path"), trust_remote_code=model_cfg.get("trust_remote_code", False))

with multiprocessing.Manager() as manager:
manager_list = manager.list()
Expand Down
8 changes: 4 additions & 4 deletions ais_bench/benchmark/utils/file/load_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
logger = AISLogger()


def load_tokenizer(tokenizer_path: str):
def load_tokenizer(tokenizer_path: str, trust_remote_code=False):
"""Load a tokenizer from the specified path.

Args:
Expand All @@ -33,7 +33,7 @@ def load_tokenizer(tokenizer_path: str):
)

try:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=trust_remote_code)
logger.debug(f"Successfully loaded tokenizer from: {tokenizer_path}")
return tokenizer
except Exception as e:
Expand All @@ -44,8 +44,8 @@ def load_tokenizer(tokenizer_path: str):


class AISTokenizer:
def __init__(self, tokenizer_path: str):
self.tokenizer = load_tokenizer(tokenizer_path)
def __init__(self, tokenizer_path: str, trust_remote_code=False):
self.tokenizer = load_tokenizer(tokenizer_path, trust_remote_code=trust_remote_code)

def encode(self, prompt: list, add_special_tokens: bool = True) -> Tuple[float, List[int]]:
"""Encode a string into tokens, measuring processing time."""
Expand Down
4 changes: 2 additions & 2 deletions tests/UT/utils/file/test_load_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_load_tokenizer_success(self, mock_auto_tokenizer):
result = load_tokenizer(self.tokenizer_path)

self.assertIs(result, mock_tokenizer)
mock_auto_tokenizer.from_pretrained.assert_called_once_with(self.tokenizer_path)
mock_auto_tokenizer.from_pretrained.assert_called_once_with(self.tokenizer_path, trust_remote_code=False)

@patch.object(lt_module, "AutoTokenizer")
def test_load_tokenizer_loading_fails_value_error(self, mock_auto_tokenizer):
Expand Down Expand Up @@ -111,7 +111,7 @@ def test_load_tokenizer_with_absolute_path(self, mock_auto_tokenizer):
result = load_tokenizer(abs_path)

self.assertIs(result, mock_tokenizer)
mock_auto_tokenizer.from_pretrained.assert_called_once_with(abs_path)
mock_auto_tokenizer.from_pretrained.assert_called_once_with(abs_path, trust_remote_code=False)

class TestLoadTokenizerEdgeCases(unittest.TestCase):
"""Edge case tests for load_tokenizer."""
Expand Down
Loading