From 974f28f48dc7462ce5643e0756d24b10d9e66f78 Mon Sep 17 00:00:00 2001 From: Shadi Noghabi Date: Mon, 26 Jan 2026 16:19:30 -0800 Subject: [PATCH] Allow config_id as an alternative model_id to automodel we don't want to force using the HF id as the only option for the model_id, especially for other sources, such as Kaggle/GCS/CNS. Alternatively, we allow the config_id from the model.py for each model family as the option to be used for model_id (and respectively model_name) for cli and Automodel. Tested by local launch of `./examples/rl/grpo/gsm8k/run_gemma2_2b.sh` on VM PiperOrigin-RevId: 861402268 --- docs/models.md | 40 +++--- examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml | 4 +- examples/rl/grpo/gsm8k/run_gemma_7b.sh | 4 +- examples/sft/mtnt/configs/gemma2_2b.yaml | 4 +- examples/sft/mtnt/run_gemma_2b.sh | 4 +- tests/models/naming_test.py | 126 ++++++++++++++++-- tests/smoke_tests/model_creation_test.py | 4 +- tunix/models/automodel.py | 45 ++++++- tunix/models/naming.py | 90 ++++++++++--- 9 files changed, 251 insertions(+), 70 deletions(-) diff --git a/docs/models.md b/docs/models.md index cd758a90d..2ddee510c 100644 --- a/docs/models.md +++ b/docs/models.md @@ -62,7 +62,7 @@ Adding the new model needs to following the naming convention that Tunix support ## AutoModel `AutoModel` provides a unified interface for instantiating Tunix models from -pretrained checkpoints, similar to the Hugging Face `AutoModel` API. It allows +pretrained checkpoints, similar to the Huggingface `AutoModel` API. It allows you to load a model simply by providing its `model_id`, handling the download and initialization for you. @@ -70,7 +70,7 @@ and initialization for you. To load a model, use the `AutoModel.from_pretrained` method with the model identifier and your JAX sharding mesh. By default this will download the model -from HuggingFace. +from Huggingface. 
```python from tunix.models.automodel import AutoModel @@ -80,9 +80,9 @@ import jax mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2) # 2. Load the model -# By default, this downloads from Hugging Face. +# By default, this downloads from Huggingface. model, model_path = AutoModel.from_pretrained( - model_id="google/gemma-2-2b-it", + model_id="google/gemma-2-2b-it", # Using HF id as model_id mesh=mesh ) @@ -94,20 +94,19 @@ print(f"Model loaded from: {model_path}") You can load models from different sources (e.g., Kaggle, GCS, etc.) using the `model_source` argument. -#### From HuggingFace: +#### From Huggingface: This is the default choice (`ModelSource.HUGGINGFACE`) as shown in the example above. #### From Kaggle: -For Kaggle, you must provide the `model_id` which is the Hugging Face identifier -(to determine the model configuration) and the `model_path` which is the Kaggle +For Kaggle, you must provide the `model_id` which is the Huggingface identifier or model_config_id (see [Naming Conventions](models.md#naming-conventions)) to determine the model configuration and the `model_path` which is the Kaggle Hub model identifier (used to download the model from Kaggle). ```python model, model_path = AutoModel.from_pretrained( - model_id="google/gemma2-2b-it", + model_id="gemma2_2b_it", # Using model_config_id as model_id mesh=mesh, model_source=ModelSource.KAGGLE, model_path="google/gemma-2/flax/gemma2-2b-it", @@ -120,13 +119,12 @@ For example the `model_path` for the `google/gemma-2/flax/gemma2-2b-it` is extra #### From GCS: -For GCS, you must provide the `model_id` which is the Hugging Face identifier -(to determine the model configuration) and the `model_path` (the actual GCS +For GCS, you must provide the `model_id` which is the Huggingface identifier or model_config_id (see [Naming Conventions](models.md#naming-conventions)) to determine the model configuration and the `model_path` (the actual GCS location). 
```python model, model_path = AutoModel.from_pretrained( - model_id="google/gemma-2-2b-it", + model_id="gemma2_2b_it", # Using model_config_id as model_id mesh=mesh, model_source=ModelSource.GCS, model_path="gs://my-bucket/gemma-2-2b-it" ) ``` @@ -139,7 +137,7 @@ Optionally, you can also provide the `model_download_path` argument, which specifies where the model is to be downloaded to. Depending on the `model_source` the effect of specifying this variable is different: -* **Hugging Face**: Files are downloaded directly to this directory. +* **Huggingface**: Files are downloaded directly to this directory. * **Kaggle**: Sets the `KAGGLEHUB_CACHE` environment variable to this path. * **GCS**: No-op. * **Internal**: Files are copied to this directory. If omitted, the model is loaded directly from the `model_path`. This mode (Internal) is not supported in OSS version. @@ -148,21 +146,27 @@ specifies where the model is to be downloaded to. Depending on the This section outlines the naming conventions used within Tunix for model identification and configuration. These conventions ensure consistency when -loading models from various sources like Hugging Face or Kaggle. +loading models from various sources like Huggingface or Kaggle. The `ModelNaming` dataclass handles the parsing and standardization of model names. -* **`model_id`**: The full model name identifier (case sensitive), as it appears - on Hugging Face, including the parent directory. For example, +* **`model_id`**: This is a unique identifier used to identify the model and extract the family, version, and desired config from it. Tunix supports two identifiers as the `model_id`: + 1. **Huggingface (HF) IDs:** The full model name identifier (case sensitive), as it appears + on Huggingface, including the parent directory. 
+ * **Extracting model_id from HF**: For example, `meta-llama/Llama-3.1-8B` is extracted as shown below: - ![Hugging Face extracting Model ID](images/model_id_huggingface.png){: width="75%"} + ![Huggingface extracting Model ID](images/model_id_huggingface.png){: width="75%"} + + 2. **Native Tunix model_configs:** the `model_config_id` representing the exact config from the model class can be used directly as the `model_id`. In this case it will also be treated as the `model_name`. + * **Extracting model_id from model_config_id**: In this case, you would need to refer to the source code (`model.py`) for each model family and select the config id from the `ModelConfig` class, for example `llama3p1_8b` from the llama [model code](https://github.com/google/tunix/blob/main/models/llama3/model.py#L138). * **`model_name`**: The unique full name identifier of the model. This corresponds to the full name and should match exactly with the model name used in Hugging Face or Kaggle. It is typically all lowercase and formatted - as `-`. - * *Example*: `gemma-2b`, `llama-3.1-8b`, `gemma2-2b-it`. + as `-` (when HF is used for model_id) or `_` (when model_config_id is used for model_id). + * *Example for HF as model_id*: `gemma-2b`, `llama-3.1-8b`, `gemma-2-2b-it`. + * *Example for model_config_id as model_id*: `gemma_2b`, `llama3p1_8b`, `gemma2_2b_it`. * **`model_family`**: The standardized model family. Unnecessary hyphens are removed, and versions are standardized (e.g., replacing dot with `p`). diff --git a/examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml b/examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml index 91f50621c..951685628 100644 --- a/examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml +++ b/examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml @@ -13,8 +13,8 @@ # limitations under the License. 
model_config: - model_name: "gemma-2-2b-it" - model_id: "google/gemma-2-2b-it" + model_name: "gemma2_2b_it" + model_id: "gemma2_2b_it" model_path: "google/gemma-2/flax/gemma2-2b-it" model_source: "kaggle" mesh: diff --git a/examples/rl/grpo/gsm8k/run_gemma_7b.sh b/examples/rl/grpo/gsm8k/run_gemma_7b.sh index 1e1f365dd..9cda82a6e 100755 --- a/examples/rl/grpo/gsm8k/run_gemma_7b.sh +++ b/examples/rl/grpo/gsm8k/run_gemma_7b.sh @@ -41,8 +41,8 @@ echo "Rounded warmup steps: $warmup_steps" python3 -m tunix.cli.grpo_main \ base_config.yaml \ - model_config.model_name="gemma-7b-it" \ - model_config.model_id="google/gemma-7b-it" \ + model_config.model_name="gemma_7b_it" \ + model_config.model_id="gemma_7b_it" \ model_config.model_path="google/gemma/flax/7b-it" \ model_config.model_source="kaggle" \ model_config.model_download_path="/tmp/models/gemma-7b" \ diff --git a/examples/sft/mtnt/configs/gemma2_2b.yaml b/examples/sft/mtnt/configs/gemma2_2b.yaml index 2b6b6bba8..a07444622 100644 --- a/examples/sft/mtnt/configs/gemma2_2b.yaml +++ b/examples/sft/mtnt/configs/gemma2_2b.yaml @@ -13,8 +13,8 @@ # limitations under the License. 
model_config: - model_name: "gemma-2-2b-it" - model_id: "google/gemma-2-2b-it" + model_name: "gemma2_2b_it" + model_id: "gemma2_2b_it" model_path: "google/gemma-2/flax/gemma2-2b-it" model_source: "kaggle" mesh: diff --git a/examples/sft/mtnt/run_gemma_2b.sh b/examples/sft/mtnt/run_gemma_2b.sh index d471aa9a2..609ca5bfe 100755 --- a/examples/sft/mtnt/run_gemma_2b.sh +++ b/examples/sft/mtnt/run_gemma_2b.sh @@ -17,8 +17,8 @@ set -x # Enable xtrace python3 -m tunix.cli.peft_main \ base_config.yaml \ - model_config.model_name="gemma-2b" \ - model_config.model_id="google/gemma-2b" \ + model_config.model_name="gemma_2b" \ + model_config.model_id="gemma_2b" \ model_config.model_path="google/gemma/flax/2b" \ model_config.model_source="kaggle" \ model_config.model_download_path="/tmp/models" \ diff --git a/tests/models/naming_test.py b/tests/models/naming_test.py index f4f27f27e..27b1daa90 100644 --- a/tests/models/naming_test.py +++ b/tests/models/naming_test.py @@ -488,6 +488,41 @@ def _get_test_cases_for_model_id_exists() -> list[dict[str, str]]: ] +def _get_test_cases_for_auto_population_with_HF_model_id() -> ( + list[dict[str, str]] +): + test_cases = [] + _validate_full_model_coverage() + for model_info in _TEST_MODEL_INFOS: + test_cases.append({ + 'testcase_name': model_info.config_id, + 'model_id': model_info.id, + 'expected_name': model_info.name, + 'expected_family': model_info.family, + 'expected_version': model_info.version, + 'expected_category': model_info.category, + 'expected_config_id': model_info.config_id, + }) + return test_cases + + +def _get_test_cases_for_auto_population_with_config_id() -> ( + list[dict[str, str]] +): + test_cases = [] + _validate_full_model_coverage() + for model_info in _TEST_MODEL_INFOS: + test_cases.append({ + 'testcase_name': model_info.config_id, + 'model_id': model_info.config_id, + 'expected_family': model_info.family, + 'expected_version': model_info.version, + 'expected_category': model_info.category, + 'expected_config_id': 
model_info.config_id, + }) + return test_cases + + class TestNaming(parameterized.TestCase): @parameterized.named_parameters( @@ -501,9 +536,15 @@ def test_get_model_name_from_model_id( expected_name, ) - def test_get_model_name_from_model_id_invalid_fails(self): - with self.assertRaisesRegex(ValueError, 'Invalid model ID format'): - naming.get_model_name_from_model_id('Llama-3.1-8B') + def test_get_model_name_from_model_id_no_slash_succeeds(self): + self.assertEqual( + naming.get_model_name_from_model_id('Llama-3.1-8B'), 'llama-3.1-8b' + ) + + def test_get_model_name_from_model_id_config_id(self): + self.assertEqual( + naming.get_model_name_from_model_id('llama3p1_8b'), 'llama3p1_8b' + ) def test_get_model_name_from_model_id_nested_path(self): self.assertEqual( @@ -544,7 +585,15 @@ def test_get_model_family_and_version( def test_get_model_family_and_version_invalid_fails(self): with self.assertRaisesRegex( - ValueError, 'Could not determine model family for: foobar.' + ValueError, 'Could not determine model family for: foo-bar.' 
+ ): + naming.get_model_family_and_version('foo-bar') + + def test_get_model_family_and_version_invalid_format_fails(self): + with self.assertRaisesRegex( + ValueError, + 'Invalid model ID format: .* Expected a Huggingface model ID or a' + ' ConfigId.', ): naming.get_model_family_and_version('foobar') @@ -555,6 +604,8 @@ def test_get_model_family_and_version_invalid_version_fails(self): def test_split(self): self.assertEqual(naming.split('gemma-7b'), ('gemma-', '7b')) self.assertEqual(naming.split('gemma-1.1-7b'), ('gemma-1.1-', '7b')) + self.assertEqual(naming.split('gemma_7b'), ('gemma_', '7b')) + self.assertEqual(naming.split('gemma1p1_7b'), ('gemma1p1_', '7b')) @parameterized.named_parameters(_get_test_cases_for_get_model_config_id()) def test_get_model_config_id(self, model_name: str, expected_config_id: str): @@ -570,15 +621,60 @@ def test_get_model_config_category( naming.get_model_config_category(model_name), expected_category ) - def test_model_naming_auto_population(self): - model_id = 'meta-llama/Llama-3.1-8B' - naming_info = naming.ModelNaming(model_id=model_id) - self.assertEqual(naming_info.model_id, model_id) - self.assertEqual(naming_info.model_name, 'llama-3.1-8b') - self.assertEqual(naming_info.model_family, 'llama3p1') - self.assertEqual(naming_info.model_version, '8b') - self.assertEqual(naming_info.model_config_category, 'llama3') - self.assertEqual(naming_info.model_config_id, 'llama3p1_8b') + @parameterized.named_parameters( + _get_test_cases_for_auto_population_with_HF_model_id() + ) + def test_model_naming_auto_population_with_HF_model_id( + self, + *, + model_id: str, + expected_name: str, + expected_family: str, + expected_version: str, + expected_category: str, + expected_config_id: str, + ): + with self.subTest(name='Test Model naming creation with HFModelId'): + naming_info = naming.ModelNaming(model_id=naming.HFModelId(model_id)) + self.assertEqual(naming_info.model_id, model_id) + self.assertEqual(naming_info.model_name, 
expected_name) + self.assertEqual(naming_info.model_family, expected_family) + self.assertEqual(naming_info.model_version, expected_version) + self.assertEqual(naming_info.model_config_category, expected_category) + self.assertEqual(naming_info.model_config_id, expected_config_id) + + with self.subTest(name='Test Model id type detection'): + self.assertTrue(naming._is_hf_model_id_type(naming_info.model_id)) + self.assertFalse(naming._is_config_id_type(naming_info.model_id)) + self.assertTrue(naming._is_hf_model_id_type(naming_info.model_name)) + self.assertFalse(naming._is_config_id_type(naming_info.model_name)) + + @parameterized.named_parameters( + _get_test_cases_for_auto_population_with_config_id() + ) + def test_model_naming_auto_population_with_config_id_model_id( + self, + *, + model_id: str, + expected_family: str, + expected_version: str, + expected_category: str, + expected_config_id: str, + ): + with self.subTest(name='Test Model naming creation with ConfigId'): + naming_info = naming.ModelNaming(model_id=naming.ConfigId(model_id)) + self.assertEqual(naming_info.model_id, model_id) + self.assertEqual(naming_info.model_name, model_id) + self.assertEqual(naming_info.model_family, expected_family) + self.assertEqual(naming_info.model_version, expected_version) + self.assertEqual(naming_info.model_config_category, expected_category) + self.assertEqual(naming_info.model_config_id, expected_config_id) + + with self.subTest(name='Test Model id type detection'): + self.assertFalse(naming._is_hf_model_id_type(naming_info.model_id)) + self.assertTrue(naming._is_config_id_type(naming_info.model_id)) + self.assertFalse(naming._is_hf_model_id_type(naming_info.model_name)) + self.assertTrue(naming._is_config_id_type(naming_info.model_name)) def test_model_naming_no_model_id(self): model_name = 'gemma-2b' @@ -608,7 +704,9 @@ def test_model_naming_mismatch(self): 'model_name set in ModelNaming and one inferred from model_id do not' ' match', ): - 
naming.ModelNaming(model_name='gemma-7b', model_id='google/gemma-2b') + naming.ModelNaming( + model_name='gemma-7b', model_id=naming.HFModelId('google/gemma-2b') + ) if __name__ == '__main__': diff --git a/tests/smoke_tests/model_creation_test.py b/tests/smoke_tests/model_creation_test.py index f8f67070b..f3e93daa9 100644 --- a/tests/smoke_tests/model_creation_test.py +++ b/tests/smoke_tests/model_creation_test.py @@ -55,9 +55,9 @@ def tearDown(self): ), dict( testcase_name="gemma2_2b_it", - model_name="gemma-2-2b-it", + model_name="gemma2_2b_it", model_source="kaggle", - model_id="google/gemma-2-2b-it", + model_id="gemma2_2b_it", model_path="google/gemma-2/flax/gemma2-2b-it", tokenizer_path=model._DEFAULT_TOKENIZER_PATH, tokenizer_type="sentencepiece", diff --git a/tunix/models/automodel.py b/tunix/models/automodel.py index 8ef1a9c9d..07817b77d 100644 --- a/tunix/models/automodel.py +++ b/tunix/models/automodel.py @@ -148,8 +148,25 @@ def create_gemma_model_with_nnx_conversion( intermediate_ckpt_dir: str, rng_seed: int, mesh: jax.sharding.Mesh, + model_path: str | None = None, ) -> tuple[nnx.Module, Any]: - """Creates a Gemma model with NNX conversion, using a cached checkpoint if available.""" + """Creates a Gemma model with NNX conversion, using a cached checkpoint if available. + + Args: + model_name: The name of the model (e.g., "gemma-2b"). + ckpt_path: The base path to the checkpoints. + intermediate_ckpt_dir: Directory to save or load the NNX converted + checkpoint. + rng_seed: The random seed for model initialization. + mesh: Mesh object for device layout. + model_path: Optional. The specific path to the model files. If None, + the path is inferred from `model_name` and `ckpt_path`. + + Returns: + A tuple containing: + - model: The loaded nnx.Module. + - model_params: The model parameters. + """ def _nnx_convert_and_reload() -> tuple[nnx.Module, Any]: """Converts the model to an NNX checkpoint and reloads it. 
@@ -157,13 +174,26 @@ def _nnx_convert_and_reload() -> tuple[nnx.Module, Any]: This is a workaround, as the checkpoints on Kaggle don't work with NNX. This takes a long time. Skip if conversion is not needed. """ - if model_name.startswith('gemma-2'): - params_path = os.path.join( - ckpt_path, model_name.replace('gemma-2', 'gemma2') + if model_path: + dir_name = os.path.basename(model_path) + else: + # If model_path is not provided, fall back to inferring from model_name + logging.warning( + 'model_path is not provided. Inferring from model_name. This may lead' + ' to incorrect results if the model_name (%s) is not a standard Gemma' + ' model name.', model_name ) - else: # gemma - suffix = '-'.join(model_name.split('-')[1:]) - params_path = os.path.join(ckpt_path, suffix) + naming_info = naming.ModelNaming(model_name=model_name) + version_dashed = naming_info.model_version.replace('_', '-') + + if naming_info.model_family == 'gemma2': + dir_name = f'gemma2-{version_dashed}' + elif naming_info.model_family == 'gemma1p1': + dir_name = f'1.1-{version_dashed}' + else: # gemma + dir_name = version_dashed + + params_path = os.path.join(ckpt_path, dir_name) model, params = create_gemma_model_from_params(params_path, model_name) @@ -419,6 +449,7 @@ def from_pretrained( intermediate_ckpt_dir=intermediate_ckpt_dir, rng_seed=rng_seed, mesh=mesh, + model_path=model_path, ) elif model_source == ModelSource.INTERNAL: model, model_params = create_gemma_model_from_params( diff --git a/tunix/models/naming.py b/tunix/models/naming.py index efab05584..112a132b2 100644 --- a/tunix/models/naming.py +++ b/tunix/models/naming.py @@ -21,17 +21,30 @@ import dataclasses +from typing import NewType import immutabledict +HFModelId = NewType('HFModelId', str) +ConfigId = NewType('ConfigId', str) + + +def _is_hf_model_id_type(model_id_or_name: str) -> bool: + return '-' in model_id_or_name or '.' 
in model_id_or_name + + +def _is_config_id_type(model_id_or_name: str) -> bool: + return not _is_hf_model_id_type(model_id_or_name) and '_' in model_id_or_name + + @dataclasses.dataclass(frozen=True) class ModelNaming: """Model naming information. Attributes: - model_id: The full model name identifier (case sensitive), as it appears on - Huggingface, including the parent directory. - E.g.,"meta-llama/Llama-3.1-8B". + model_id: A unique identifier for the model, which can be either a + Huggingface model ID (e.g., "meta-llama/Llama-3.1-8B") or a standardized + ConfigId (e.g., "llama3p1_8b"). model_name: The unique full name identifier of the model. This should be the full name and should match exactly with the model name used in Hugging Face. e.g., "gemma-2b","llama-3.1-8b". The model name is all lowercase and @@ -53,8 +66,9 @@ class ModelNaming: model version, used in the ModelConfig class. e.g., "gemma_2b_it" or "qwen2p5_0p5b". """ - - model_id: str | None = None + # TODO(b/451662153): use HFModelId and ConfigId throughout, add validation, + # and then remove str support. + model_id: HFModelId | ConfigId | str | None = None model_name: str | None = None model_family: str = dataclasses.field(init=False) model_version: str = dataclasses.field(init=False) @@ -97,10 +111,8 @@ class _ModelFamilyInfo: config_category: str # category in the path to the ModelConfig class -# Mapping of all model families from the hugging face model id to the internal -# model_family and config_category. Key is the prefix of the hugging face model -# id and value is the internal model family and config_category. -_MODEL_FAMILY_INFO_MAPPING = immutabledict.immutabledict({ +# HF model family info mapping. 
+_HF_MODEL_FAMILY_INFO_MAPPING = immutabledict.immutabledict({ 'gemma-': _ModelFamilyInfo(family='gemma', config_category='gemma'), 'gemma1.1-': _ModelFamilyInfo(family='gemma1p1', config_category='gemma'), 'gemma-1.1-': _ModelFamilyInfo(family='gemma1p1', config_category='gemma'), @@ -121,12 +133,43 @@ class _ModelFamilyInfo: ), }) +# Config id model family info mapping. +_CONFIG_ID_MODEL_FAMILY_INFO_MAPPING = immutabledict.immutabledict({ + 'gemma_': _ModelFamilyInfo(family='gemma', config_category='gemma'), + 'gemma1p1_': _ModelFamilyInfo(family='gemma1p1', config_category='gemma'), + 'gemma2_': _ModelFamilyInfo(family='gemma2', config_category='gemma'), + 'gemma3_': _ModelFamilyInfo(family='gemma3', config_category='gemma3'), + 'llama3_': _ModelFamilyInfo(family='llama3', config_category='llama3'), + 'llama3p1_': _ModelFamilyInfo(family='llama3p1', config_category='llama3'), + 'llama3p2_': _ModelFamilyInfo(family='llama3p2', config_category='llama3'), + 'qwen2p5_': _ModelFamilyInfo(family='qwen2p5', config_category='qwen2'), + 'qwen3_': _ModelFamilyInfo(family='qwen3', config_category='qwen3'), + 'deepseek_r1_distill_qwen_': _ModelFamilyInfo( + family='deepseek_r1_distill_qwen', config_category='qwen2' + ), +}) + + +def _get_model_family_mapping( + model_name: str, +) -> immutabledict.immutabledict[str, _ModelFamilyInfo]: + """Returns the model family mapping based on the model name format.""" + if _is_hf_model_id_type(model_name): + return _HF_MODEL_FAMILY_INFO_MAPPING + elif _is_config_id_type(model_name): + return _CONFIG_ID_MODEL_FAMILY_INFO_MAPPING + else: + raise ValueError( + f'Invalid model ID format: {model_name!r}. Expected a Huggingface' + ' model ID or a ConfigId.' + ) + def split(model_name: str) -> tuple[str, str]: """Splits model name into model family and model version. Find the longest matching prefix of the model name in the - _MODEL_FAMILY_INFO_MAPPING. Returns the remaining string as the model version, + model family info mapping. 
Returns the remaining string as the model version, stripping leading hyphens. Args: @@ -136,8 +179,9 @@ def split(model_name: str) -> tuple[str, str]: A tuple containing the un-standardized model_family and model_version. """ model_name = model_name.lower() + mapping = _get_model_family_mapping(model_name) matched_family = '' - for family in _MODEL_FAMILY_INFO_MAPPING: + for family in mapping: if model_name.startswith(family) and len(family) > len(matched_family): matched_family = family if matched_family: @@ -146,7 +190,7 @@ def split(model_name: str) -> tuple[str, str]: raise ValueError( f'Could not determine model family for: {model_name}. Not one of the' ' known families:' - f' {list(_MODEL_FAMILY_INFO_MAPPING.keys())}' + f' {list(mapping.keys())}' ) @@ -181,7 +225,8 @@ def _standardize_model_version(raw_model_version: str) -> str: def get_model_family_and_version(model_name: str) -> tuple[str, str]: """Splits model name into internal, standardized model family and model version.""" raw_model_family, raw_model_version = split(model_name) - model_family = _MODEL_FAMILY_INFO_MAPPING[raw_model_family].family + mapping = _get_model_family_mapping(model_name) + model_family = mapping[raw_model_family].family model_version = _standardize_model_version(raw_model_version) return model_family, model_version @@ -189,7 +234,8 @@ def get_model_family_and_version(model_name: str) -> tuple[str, str]: def get_model_config_category(model_name: str) -> str: """Returns the model config category from the model family.""" raw_model_family, _ = split(model_name) - return _MODEL_FAMILY_INFO_MAPPING[raw_model_family].config_category + mapping = _get_model_family_mapping(model_name) + return mapping[raw_model_family].config_category def get_model_config_id(model_name: str) -> str: @@ -200,17 +246,18 @@ def get_model_config_id(model_name: str) -> str: return config_id -def get_model_name_from_model_id(model_id: str) -> str: +def get_model_name_from_model_id(model_id: HFModelId | 
ConfigId | str) -> str: """Extracts model name from model ID by taking the last part of path. Args: model_id: The full model name identifier, as it appears on huggingface, - including the parent directory. E.g., meta-llama/Llama-3.1-8B. + including the parent directory. E.g., meta-llama/Llama-3.1-8B. Can also be + the model_config_id directly, e.g., llama3p1_8b. Returns: The model_name string. """ - if '/' in model_id: + if _is_hf_model_id_type(model_id) or '/' in model_id: model_name = model_id.split('/')[-1].lower() if not model_name: raise ValueError( @@ -219,8 +266,9 @@ def get_model_name_from_model_id(model_id: str) -> str: if model_name.startswith('meta-llama-'): return model_name.replace('meta-llama-', 'llama-', 1) return model_name + elif _is_config_id_type(model_id): + return model_id.lower() else: - raise ValueError( - f'Invalid model ID format: {model_id!r}. Model ID should be in the' - ' format of /' - ) + # If the model_id is not a HFModelId or ConfigId, we assume it is already + # a model_name and convert it to lowercase to be consistent. + return model_id.lower()