diff --git a/kernels/src/kernels/cli/__init__.py b/kernels/src/kernels/cli/__init__.py
index 8d84df48..86dd0aa5 100644
--- a/kernels/src/kernels/cli/__init__.py
+++ b/kernels/src/kernels/cli/__init__.py
@@ -297,12 +297,14 @@ def download_kernels(args):
     for kernel_lock_json in lock_json:
         kernel_lock = KernelLock.from_json(kernel_lock_json)
         print(
-            f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
+            f"Downloading `{kernel_lock.repo_id}` with SHA: {kernel_lock.sha}",
             file=sys.stderr,
         )
         if args.all_variants:
             install_kernel_all_variants(
-                kernel_lock.repo_id, kernel_lock.sha, variant_locks=kernel_lock.variants
+                kernel_lock.repo_id,
+                kernel_lock.sha,
+                variant_locks=kernel_lock.variants,
             )
         else:
             try:
diff --git a/kernels/src/kernels/lockfile.py b/kernels/src/kernels/lockfile.py
index 58cf78b3..2aca2eef 100644
--- a/kernels/src/kernels/lockfile.py
+++ b/kernels/src/kernels/lockfile.py
@@ -4,6 +4,7 @@
 
 from kernels._versions import resolve_version_spec_as_ref
 from kernels.compat import tomllib
+from kernels.status import resolve_status
 
 
 @dataclass
@@ -35,9 +36,16 @@ def get_kernel_locks(repo_id: str, version_spec: int | str) -> KernelLock:
     """
     from kernels.utils import _get_hf_api
 
+    api = _get_hf_api()
+
+    # NOTE: the destination of a redirect is respected but we still use
+    # resolve_version_spec_as_ref to resolve the version specifier of the
+    # final destination repo.
+    repo_id, _ = resolve_status(api, repo_id, "main")
+
     tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)
 
-    r = _get_hf_api().repo_info(
+    r = api.repo_info(
         repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
     )
     if r.sha is None:
diff --git a/kernels/src/kernels/status.py b/kernels/src/kernels/status.py
new file mode 100644
index 00000000..f22cc099
--- /dev/null
+++ b/kernels/src/kernels/status.py
@@ -0,0 +1,108 @@
+import warnings
+from dataclasses import dataclass
+from typing import Union
+
+from huggingface_hub import HfApi
+from huggingface_hub.utils import EntryNotFoundError
+
+from kernels.compat import tomllib
+
+
+@dataclass
+class Redirect:
+    """Kernel status that points a repository at a replacement repository."""
+
+    kind: str  # must be "redirect"
+    destination: str
+    revision: str
+
+    @staticmethod
+    def from_dict(data: dict) -> "Redirect":
+        """Build a `Redirect` from parsed `kernel-status.toml` data.
+
+        `revision` defaults to `"main"` when the file does not set it.
+
+        Raises:
+            `ValueError`: If `kind` is not `"redirect"` or `destination` is
+                missing or empty.
+        """
+        if data.get("kind") != "redirect":
+            raise ValueError("kernel-status.toml kind must be 'redirect' for Redirect")
+        destination = data.get("destination")
+        if not destination:
+            raise ValueError("kernel-status.toml must contain a 'destination' field")
+        return Redirect(
+            kind="redirect",
+            destination=destination,
+            revision=data.get("revision", "main"),
+        )
+
+
+KernelStatusKind = Union[Redirect]
+
+
+class KernelStatus:
+    """Parse and fetch the optional `kernel-status.toml` of a kernel repository."""
+
+    @staticmethod
+    def from_toml(content: str) -> KernelStatusKind:
+        """Parse the text of a `kernel-status.toml` file.
+
+        Raises:
+            `ValueError`: If `kind` is missing or unknown, or a redirect has
+                no `destination`.
+        """
+        data = tomllib.loads(content)
+
+        kind = data.get("kind")
+        if not kind:
+            raise ValueError("kernel-status.toml must contain a 'kind' field")
+
+        if kind == "redirect":
+            return Redirect.from_dict(data)
+
+        raise ValueError(f"Unknown kernel status kind: {kind!r}")
+
+    @staticmethod
+    def check_status(
+        api: HfApi, repo_id: str, revision: str
+    ) -> KernelStatusKind | None:
+        """Fetch the kernel status from the repository, if it exists.
+
+        Returns `None` when the download reports that `kernel-status.toml`
+        is not present at `revision`.
+        """
+        # Only the download signals a missing file, so keep parsing outside
+        # the `try` block.
+        try:
+            path = api.hf_hub_download(
+                repo_id=repo_id, filename="kernel-status.toml", revision=revision
+            )
+        except EntryNotFoundError:
+            return None
+
+        # TOML documents are UTF-8; do not rely on the locale encoding.
+        with open(path, "r", encoding="utf-8") as f:
+            return KernelStatus.from_toml(f.read())
+
+
+def resolve_status(api: HfApi, repo_id: str, revision: str) -> tuple[str, str]:
+    """Resolve the repository and revision to use after applying its status.
+
+    Returns the redirect destination and its revision (emitting a
+    `UserWarning`) when the repository declares a redirect; otherwise returns
+    `repo_id` and `revision` unchanged. Only a single redirect is followed.
+    """
+    status = KernelStatus.check_status(api, repo_id, revision)
+    if status is None:
+        return repo_id, revision
+
+    # In the case of a redirect, return the destination repo and revision
+    if isinstance(status, Redirect):
+        warnings.warn(
+            f"'{repo_id}' redirected to '{status.destination}'",
+            UserWarning,
+            stacklevel=2,
+        )
+        return status.destination, status.revision
+
+    return repo_id, revision
diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py
index 7ee546f1..ec8e3a33 100644
--- a/kernels/src/kernels/utils.py
+++ b/kernels/src/kernels/utils.py
@@ -21,6 +21,7 @@
 from kernels.deps import validate_dependencies
 from kernels.lockfile import KernelLock, VariantLock
 from kernels.metadata import Metadata
+from kernels.status import resolve_status
 
 KNOWN_BACKENDS = {"cpu", "cuda", "metal", "rocm", "xpu", "npu"}
 
@@ -231,9 +232,13 @@ def install_kernel(
     Returns:
         `tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
     """
+    api = _get_hf_api(user_agent=user_agent)
+
+    if not local_files_only:
+        repo_id, revision = resolve_status(api, repo_id, revision)
+
     package_name = package_name_from_repo_id(repo_id)
     allow_patterns = [f"build/{variant}/*" for variant in _build_variants(backend)]
-    api = _get_hf_api(user_agent=user_agent)
     repo_path = Path(
         str(
             api.snapshot_download(
@@ -443,7 +448,6 @@ def has_kernel(
         `bool`: `True` if a kernel is available for the current environment.
     """
     revision = select_revision_or_version(repo_id, revision=revision, version=version)
-    package_name = package_name_from_repo_id(repo_id)
 
     api = _get_hf_api()
 
diff --git a/kernels/tests/test_status.py b/kernels/tests/test_status.py
new file mode 100644
index 00000000..d9d5bd2b
--- /dev/null
+++ b/kernels/tests/test_status.py
@@ -0,0 +1,93 @@
+import pytest
+from unittest.mock import MagicMock
+
+from kernels.status import (
+    Redirect,
+    KernelStatus,
+    resolve_status,
+)
+
+
+class TestKernelStatusFromToml:
+    def test_simple_redirect(self):
+        content = '''kind = "redirect"
+destination = "kernels-community/new-kernel"'''
+        result = KernelStatus.from_toml(content)
+        assert isinstance(result, Redirect)
+        assert result.destination == "kernels-community/new-kernel"
+        assert result.revision == "main"
+
+    def test_redirect_with_revision(self):
+        content = '''kind = "redirect"
+destination = "kernels-community/new-kernel"
+revision = "v2"'''
+        result = KernelStatus.from_toml(content)
+        assert isinstance(result, Redirect)
+        assert result.destination == "kernels-community/new-kernel"
+        assert result.revision == "v2"
+
+    def test_missing_kind_raises(self):
+        with pytest.raises(ValueError, match="must contain a 'kind' field"):
+            KernelStatus.from_toml('destination = "kernels-community/new-kernel"')
+
+    def test_unknown_kind_raises(self):
+        content = '''kind = "unknown"
+destination = "kernels-community/new-kernel"'''
+        with pytest.raises(ValueError, match="Unknown kernel status kind"):
+            KernelStatus.from_toml(content)
+
+    def test_missing_destination_raises(self):
+        with pytest.raises(ValueError, match="must contain a 'destination'"):
+            KernelStatus.from_toml('kind = "redirect"')
+
+    def test_empty_content_raises(self):
+        with pytest.raises(ValueError, match="must contain a 'kind'"):
+            KernelStatus.from_toml("")
+
+
+class TestResolveStatus:
+    def test_no_status(self):
+        from huggingface_hub.utils import EntryNotFoundError
+
+        mock_api = MagicMock()
+        mock_api.hf_hub_download.side_effect = EntryNotFoundError("Not found")
+
+        repo_id, revision = resolve_status(mock_api, "kernels-test/kernel", "main")
+        assert repo_id == "kernels-test/kernel"
+        assert revision == "main"
+
+    def test_redirect(self, tmp_path):
+        from huggingface_hub.utils import EntryNotFoundError
+
+        status_file = tmp_path / "kernel-status.toml"
+        status_file.write_text('kind = "redirect"\ndestination = "kernels-community/new-kernel"')
+
+        def mock_download(repo_id, filename, revision):
+            if repo_id == "kernels-test/old-kernel":
+                return str(status_file)
+            raise EntryNotFoundError("Not found")
+
+        mock_api = MagicMock()
+        mock_api.hf_hub_download.side_effect = mock_download
+
+        repo_id, revision = resolve_status(mock_api, "kernels-test/old-kernel", "main")
+        assert repo_id == "kernels-community/new-kernel"
+        assert revision == "main"
+
+    def test_redirect_with_revision(self, tmp_path):
+        from huggingface_hub.utils import EntryNotFoundError
+
+        status_file = tmp_path / "kernel-status.toml"
+        status_file.write_text('kind = "redirect"\ndestination = "kernels-community/new-kernel"\nrevision = "v2"')
+
+        def mock_download(repo_id, filename, revision):
+            if repo_id == "kernels-test/old-kernel":
+                return str(status_file)
+            raise EntryNotFoundError("Not found")
+
+        mock_api = MagicMock()
+        mock_api.hf_hub_download.side_effect = mock_download
+
+        repo_id, revision = resolve_status(mock_api, "kernels-test/old-kernel", "main")
+        assert repo_id == "kernels-community/new-kernel"
+        assert revision == "v2"