40 changes: 22 additions & 18 deletions docs/models.md
@@ -62,15 +62,15 @@ Adding the new model needs to follow the naming convention that Tunix supports
## AutoModel

`AutoModel` provides a unified interface for instantiating Tunix models from
pretrained checkpoints, similar to the Hugging Face `AutoModel` API. It allows
pretrained checkpoints, similar to the Huggingface `AutoModel` API. It allows
you to load a model simply by providing its `model_id`, handling the download
and initialization for you.

### Basic Usage

To load a model, use the `AutoModel.from_pretrained` method with the model
identifier and your JAX sharding mesh. By default this will download the model
from HuggingFace.
from Huggingface.

```python
from tunix.models.automodel import AutoModel
@@ -80,9 +80,9 @@ import jax
mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)

# 2. Load the model
# By default, this downloads from Hugging Face.
# By default, this downloads from Huggingface.
model, model_path = AutoModel.from_pretrained(
model_id="google/gemma-2-2b-it",
model_id="google/gemma-2-2b-it", # Using HF id as model_id
mesh=mesh
)

@@ -94,20 +94,19 @@ print(f"Model loaded from: {model_path}")
You can load models from different sources (e.g., Kaggle, GCS) using the
`model_source` argument.

#### From HuggingFace:
#### From Huggingface:

This is the default choice (`ModelSource.HUGGINGFACE`) as shown in the
example above.

#### From Kaggle:

For Kaggle, you must provide the `model_id` which is the Hugging Face identifier
(to determine the model configuration) and the `model_path` which is the Kaggle
For Kaggle, you must provide the `model_id`, which is the Huggingface identifier or `model_config_id` (see [Naming Conventions](models.md#naming-conventions)) used to determine the model configuration, and the `model_path`, which is the Kaggle
Hub model identifier (used to download the model from Kaggle).

```python
model, model_path = AutoModel.from_pretrained(
model_id="google/gemma2-2b-it",
model_id="gemma2_2b_it", # Using model_config_id as model_id
mesh=mesh,
model_source=ModelSource.KAGGLE,
model_path="google/gemma-2/flax/gemma2-2b-it",
@@ -120,13 +119,12 @@ For example the `model_path` for the `google/gemma-2/flax/gemma2-2b-it` is extra

#### From GCS:

For GCS, you must provide the `model_id` which is the Hugging Face identifier
(to determine the model configuration) and the `model_path` (the actual GCS
For GCS, you must provide the `model_id`, which is the Huggingface identifier or `model_config_id` (see [Naming Conventions](models.md#naming-conventions)) used to determine the model configuration, and the `model_path` (the actual GCS
location).

```python
model, model_path = AutoModel.from_pretrained(
model_id="google/gemma-2-2b-it",
model_id="gemma2_2b_it", # Using model_config_id as model_id
mesh=mesh,
model_source=ModelSource.GCS,
model_path="gs://my-bucket/gemma-2-2b-it"
@@ -139,7 +137,7 @@ Optionally, you can also provide the `model_download_path` argument, which
specifies where the model is to be downloaded to. Depending on the
`model_source`, the effect of specifying this argument differs (see the sketch after this list):

* **Hugging Face**: Files are downloaded directly to this directory.
* **Huggingface**: Files are downloaded directly to this directory.
* **Kaggle**: Sets the `KAGGLEHUB_CACHE` environment variable to this path.
* **GCS**: No-op.
* **Internal**: Files are copied to this directory. If omitted, the model is loaded directly from the `model_path`. This mode (Internal) is not supported in the OSS version.
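
As an illustration, a minimal sketch of combining `model_download_path` with the default Huggingface source; the target directory is a hypothetical example:

```python
# A minimal sketch, reusing the mesh from the basic usage example above.
model, model_path = AutoModel.from_pretrained(
    model_id="google/gemma-2-2b-it",
    mesh=mesh,
    # With the default Huggingface source, files are downloaded
    # directly into this (hypothetical) directory.
    model_download_path="/tmp/models/gemma-2-2b-it",
)
```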
@@ -148,21 +146,27 @@ specifies where the model is to be downloaded to. Depending on the

This section outlines the naming conventions used within Tunix for model
identification and configuration. These conventions ensure consistency when
loading models from various sources like Hugging Face or Kaggle.
loading models from various sources like Huggingface or Kaggle.

The `ModelNaming` dataclass handles the parsing and standardization of model names.

* **`model_id`**: The full model name identifier (case sensitive), as it appears
on Hugging Face, including the parent directory. For example,
* **`model_id`**: A unique identifier used to identify the model and to extract its family, version, and desired config. Tunix supports two kinds of identifier as the `model_id` (see the sketch at the end of this section):
1. **Huggingface (HF) IDs:** The full model name identifier (case sensitive), as it appears
on Huggingface, including the parent directory.
* **Extracting model_id from HF**: For example,
`meta-llama/Llama-3.1-8B` is extracted as shown below:
![Hugging Face extracting Model ID](images/model_id_huggingface.png){: width="75%"}
![Huggingface extracting Model ID](images/model_id_huggingface.png){: width="75%"}

    2. **Native Tunix model_configs:** The `model_config_id`, representing the exact config from the model class, can be used directly as the `model_id`. In this case it is also treated as the `model_name`.
        * **Extracting model_id from model_config_id**: In this case, refer to the source code (`model.py`) for each model family and select the config id from the `ModelConfig` class, for example `llama3p1_8b` from the llama [model code](https://github.com/google/tunix/blob/main/models/llama3/model.py#L138).


* **`model_name`**: The unique full name identifier of the model. This
corresponds to the full name and should match exactly with the model name
used in Hugging Face or Kaggle. It is typically all lowercase and formatted
as `<model-family>-<model-version>`.
* *Example*: `gemma-2b`, `llama-3.1-8b`, `gemma2-2b-it`.
as `<model-family>-<model-version>` (when an HF id is used as the `model_id`) or `<model-family>_<model-version>` (when a `model_config_id` is used as the `model_id`).
* *Example for HF as model_id*: `gemma-2b`, `llama-3.1-8b`, `gemma-2-2b-it`.
* *Example for model_config_id as model_id*: `gemma_2b`, `llama3p1_8b`, `gemma2_2b_it`.

* **`model_family`**: The standardized model family. Unnecessary hyphens are
removed, and versions are standardized (e.g., replacing dot with `p`).
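
As an illustration, a minimal sketch of how `ModelNaming` resolves the two identifier types, based on the tests in this change; the `tunix.models.naming` import path is an assumption, and the printed values are drawn from those tests:

```python
from tunix.models import naming  # import path assumed

# Huggingface id as model_id: name, family, version, and config id
# are all derived from the id.
info = naming.ModelNaming(model_id=naming.HFModelId("meta-llama/Llama-3.1-8B"))
print(info.model_name)       # llama-3.1-8b
print(info.model_family)     # llama3p1
print(info.model_version)    # 8b
print(info.model_config_id)  # llama3p1_8b

# Native config id as model_id: the id is also treated as the model_name.
info = naming.ModelNaming(model_id=naming.ConfigId("llama3p1_8b"))
print(info.model_name)       # llama3p1_8b
print(info.model_config_id)  # llama3p1_8b
```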
4 changes: 2 additions & 2 deletions examples/rl/grpo/gsm8k/configs/gemma2_2b.yaml
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.

model_config:
model_name: "gemma-2-2b-it"
model_id: "google/gemma-2-2b-it"
model_name: "gemma2_2b_it"
model_id: "gemma2_2b_it"
model_path: "google/gemma-2/flax/gemma2-2b-it"
model_source: "kaggle"
mesh:
4 changes: 2 additions & 2 deletions examples/rl/grpo/gsm8k/run_gemma_7b.sh
@@ -41,8 +41,8 @@ echo "Rounded warmup steps: $warmup_steps"

python3 -m tunix.cli.grpo_main \
base_config.yaml \
model_config.model_name="gemma-7b-it" \
model_config.model_id="google/gemma-7b-it" \
model_config.model_name="gemma_7b_it" \
model_config.model_id="gemma_7b_it" \
model_config.model_path="google/gemma/flax/7b-it" \
model_config.model_source="kaggle" \
model_config.model_download_path="/tmp/models/gemma-7b" \
4 changes: 2 additions & 2 deletions examples/sft/mtnt/configs/gemma2_2b.yaml
Expand Up @@ -13,8 +13,8 @@
# limitations under the License.

model_config:
model_name: "gemma-2-2b-it"
model_id: "google/gemma-2-2b-it"
model_name: "gemma2_2b_it"
model_id: "gemma2_2b_it"
model_path: "google/gemma-2/flax/gemma2-2b-it"
model_source: "kaggle"
mesh:
4 changes: 2 additions & 2 deletions examples/sft/mtnt/run_gemma_2b.sh
@@ -17,8 +17,8 @@ set -x # Enable xtrace

python3 -m tunix.cli.peft_main \
base_config.yaml \
model_config.model_name="gemma-2b" \
model_config.model_id="google/gemma-2b" \
model_config.model_name="gemma_2b" \
model_config.model_id="gemma_2b" \
model_config.model_path="google/gemma/flax/2b" \
model_config.model_source="kaggle" \
model_config.model_download_path="/tmp/models" \
126 changes: 112 additions & 14 deletions tests/models/naming_test.py
@@ -488,6 +488,41 @@ def _get_test_cases_for_model_id_exists() -> list[dict[str, str]]:
]


def _get_test_cases_for_auto_population_with_HF_model_id() -> (
list[dict[str, str]]
):
test_cases = []
_validate_full_model_coverage()
for model_info in _TEST_MODEL_INFOS:
test_cases.append({
'testcase_name': model_info.config_id,
'model_id': model_info.id,
'expected_name': model_info.name,
'expected_family': model_info.family,
'expected_version': model_info.version,
'expected_category': model_info.category,
'expected_config_id': model_info.config_id,
})
return test_cases


def _get_test_cases_for_auto_population_with_config_id() -> (
list[dict[str, str]]
):
test_cases = []
_validate_full_model_coverage()
for model_info in _TEST_MODEL_INFOS:
test_cases.append({
'testcase_name': model_info.config_id,
'model_id': model_info.config_id,
'expected_family': model_info.family,
'expected_version': model_info.version,
'expected_category': model_info.category,
'expected_config_id': model_info.config_id,
})
return test_cases


class TestNaming(parameterized.TestCase):

@parameterized.named_parameters(
@@ -501,9 +536,15 @@ def test_get_model_name_from_model_id(
expected_name,
)

def test_get_model_name_from_model_id_invalid_fails(self):
with self.assertRaisesRegex(ValueError, 'Invalid model ID format'):
naming.get_model_name_from_model_id('Llama-3.1-8B')
def test_get_model_name_from_model_id_no_slash_succeeds(self):
self.assertEqual(
naming.get_model_name_from_model_id('Llama-3.1-8B'), 'llama-3.1-8b'
)

def test_get_model_name_from_model_id_config_id(self):
self.assertEqual(
naming.get_model_name_from_model_id('llama3p1_8b'), 'llama3p1_8b'
)

def test_get_model_name_from_model_id_nested_path(self):
self.assertEqual(
@@ -544,7 +585,15 @@ def test_get_model_family_and_version(

def test_get_model_family_and_version_invalid_fails(self):
with self.assertRaisesRegex(
ValueError, 'Could not determine model family for: foobar.'
ValueError, 'Could not determine model family for: foo-bar.'
):
naming.get_model_family_and_version('foo-bar')

def test_get_model_family_and_version_invalid_format_fails(self):
with self.assertRaisesRegex(
ValueError,
'Invalid model ID format: .* Expected a Huggingface model ID or a'
' ConfigId.',
):
naming.get_model_family_and_version('foobar')

@@ -555,6 +604,8 @@ def test_get_model_family_and_version_invalid_version_fails(self):
def test_split(self):
self.assertEqual(naming.split('gemma-7b'), ('gemma-', '7b'))
self.assertEqual(naming.split('gemma-1.1-7b'), ('gemma-1.1-', '7b'))
self.assertEqual(naming.split('gemma_7b'), ('gemma_', '7b'))
self.assertEqual(naming.split('gemma1p1_7b'), ('gemma1p1_', '7b'))

@parameterized.named_parameters(_get_test_cases_for_get_model_config_id())
def test_get_model_config_id(self, model_name: str, expected_config_id: str):
@@ -570,15 +621,60 @@ def test_get_model_config_category(
naming.get_model_config_category(model_name), expected_category
)

def test_model_naming_auto_population(self):
model_id = 'meta-llama/Llama-3.1-8B'
naming_info = naming.ModelNaming(model_id=model_id)
self.assertEqual(naming_info.model_id, model_id)
self.assertEqual(naming_info.model_name, 'llama-3.1-8b')
self.assertEqual(naming_info.model_family, 'llama3p1')
self.assertEqual(naming_info.model_version, '8b')
self.assertEqual(naming_info.model_config_category, 'llama3')
self.assertEqual(naming_info.model_config_id, 'llama3p1_8b')
@parameterized.named_parameters(
_get_test_cases_for_auto_population_with_HF_model_id()
)
def test_model_naming_auto_population_with_HF_model_id(
self,
*,
model_id: str,
expected_name: str,
expected_family: str,
expected_version: str,
expected_category: str,
expected_config_id: str,
):
with self.subTest(name='Test Model naming creation with HFModelId'):
naming_info = naming.ModelNaming(model_id=naming.HFModelId(model_id))
self.assertEqual(naming_info.model_id, model_id)
self.assertEqual(naming_info.model_name, expected_name)
self.assertEqual(naming_info.model_family, expected_family)
self.assertEqual(naming_info.model_version, expected_version)
self.assertEqual(naming_info.model_config_category, expected_category)
self.assertEqual(naming_info.model_config_id, expected_config_id)

with self.subTest(name='Test Model id type detection'):
self.assertTrue(naming._is_hf_model_id_type(naming_info.model_id))
self.assertFalse(naming._is_config_id_type(naming_info.model_id))
self.assertTrue(naming._is_hf_model_id_type(naming_info.model_name))
self.assertFalse(naming._is_config_id_type(naming_info.model_name))

@parameterized.named_parameters(
_get_test_cases_for_auto_population_with_config_id()
)
def test_model_naming_auto_population_with_config_id_model_id(
self,
*,
model_id: str,
expected_family: str,
expected_version: str,
expected_category: str,
expected_config_id: str,
):
with self.subTest(name='Test Model naming creation with ConfigId'):
naming_info = naming.ModelNaming(model_id=naming.ConfigId(model_id))
self.assertEqual(naming_info.model_id, model_id)
self.assertEqual(naming_info.model_name, model_id)
self.assertEqual(naming_info.model_family, expected_family)
self.assertEqual(naming_info.model_version, expected_version)
self.assertEqual(naming_info.model_config_category, expected_category)
self.assertEqual(naming_info.model_config_id, expected_config_id)

with self.subTest(name='Test Model id type detection'):
self.assertFalse(naming._is_hf_model_id_type(naming_info.model_id))
self.assertTrue(naming._is_config_id_type(naming_info.model_id))
self.assertFalse(naming._is_hf_model_id_type(naming_info.model_name))
self.assertTrue(naming._is_config_id_type(naming_info.model_name))

def test_model_naming_no_model_id(self):
model_name = 'gemma-2b'
@@ -608,7 +704,9 @@ def test_model_naming_mismatch(self):
'model_name set in ModelNaming and one inferred from model_id do not'
' match',
):
naming.ModelNaming(model_name='gemma-7b', model_id='google/gemma-2b')
naming.ModelNaming(
model_name='gemma-7b', model_id=naming.HFModelId('google/gemma-2b')
)


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions tests/smoke_tests/model_creation_test.py
@@ -55,9 +55,9 @@ def tearDown(self):
),
dict(
testcase_name="gemma2_2b_it",
model_name="gemma-2-2b-it",
model_name="gemma2_2b_it",
model_source="kaggle",
model_id="google/gemma-2-2b-it",
model_id="gemma2_2b_it",
model_path="google/gemma-2/flax/gemma2-2b-it",
tokenizer_path=model._DEFAULT_TOKENIZER_PATH,
tokenizer_type="sentencepiece",