Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions kernels/src/kernels/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,14 @@ def download_kernels(args):
for kernel_lock_json in lock_json:
kernel_lock = KernelLock.from_json(kernel_lock_json)
print(
f"Downloading `{kernel_lock.repo_id}` at with SHA: {kernel_lock.sha}",
f"Downloading `{kernel_lock.repo_id}` with SHA: {kernel_lock.sha}",
file=sys.stderr,
)
if args.all_variants:
install_kernel_all_variants(
kernel_lock.repo_id, kernel_lock.sha, variant_locks=kernel_lock.variants
kernel_lock.repo_id,
kernel_lock.sha,
variant_locks=kernel_lock.variants,
)
else:
try:
Expand Down
10 changes: 9 additions & 1 deletion kernels/src/kernels/lockfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from kernels._versions import resolve_version_spec_as_ref
from kernels.compat import tomllib
from kernels.status import resolve_status


@dataclass
Expand Down Expand Up @@ -35,9 +36,16 @@ def get_kernel_locks(repo_id: str, version_spec: int | str) -> KernelLock:
"""
from kernels.utils import _get_hf_api

api = _get_hf_api()

# NOTE: the destination of a redirect is respected but we still use
# resolve_version_spec_as_ref to resolve the version specifier of the
# final destination repo.
repo_id, _ = resolve_status(api, repo_id, "main")

tag_for_newest = resolve_version_spec_as_ref(repo_id, version_spec)

r = _get_hf_api().repo_info(
r = api.repo_info(
repo_id=repo_id, revision=tag_for_newest.target_commit, files_metadata=True
)
if r.sha is None:
Expand Down
77 changes: 77 additions & 0 deletions kernels/src/kernels/status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import warnings
from dataclasses import dataclass
from typing import Union

from huggingface_hub import HfApi
from huggingface_hub.utils import EntryNotFoundError

from kernels.compat import tomllib


@dataclass
class Redirect:
kind: str # must be "redirect"
destination: str
revision: str

@staticmethod
def from_dict(data: dict) -> "Redirect":
if data.get("kind") != "redirect":
raise ValueError("kernel-status.toml kind must be 'redirect' for Redirect")
destination = data.get("destination")
if not destination:
raise ValueError("kernel-status.toml must contain a 'destination' field")
return Redirect(
kind="redirect",
destination=destination,
revision=data.get("revision", "main"),
)


KernelStatusKind = Union[Redirect]


class KernelStatus:
@staticmethod
def from_toml(content: str) -> KernelStatusKind:
data = tomllib.loads(content)

kind = data.get("kind")
if not kind:
raise ValueError("kernel-status.toml must contain a 'kind' field")

if kind == "redirect":
return Redirect.from_dict(data)

raise ValueError(f"Unknown kernel status kind: {kind!r}")

# Fetch the kernel status from the repository, if it exists
@staticmethod
def check_status(
api: HfApi, repo_id: str, revision: str
) -> KernelStatusKind | None:
try:
path = api.hf_hub_download(
repo_id=repo_id, filename="kernel-status.toml", revision=revision
)
with open(path, "r") as f:
return KernelStatus.from_toml(f.read())
except EntryNotFoundError:
return None


def resolve_status(api: HfApi, repo_id: str, revision: str) -> tuple[str, str]:
status = KernelStatus.check_status(api, repo_id, revision)
if status is None:
return repo_id, revision

# In the case of a redirect, return the destination repo and revision
if isinstance(status, Redirect):
warnings.warn(
f"'{repo_id}' redirected to '{status.destination}'",
UserWarning,
stacklevel=2,
)
return status.destination, status.revision

return repo_id, revision
8 changes: 6 additions & 2 deletions kernels/src/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from kernels.deps import validate_dependencies
from kernels.lockfile import KernelLock, VariantLock
from kernels.metadata import Metadata
from kernels.status import resolve_status

KNOWN_BACKENDS = {"cpu", "cuda", "metal", "rocm", "xpu", "npu"}

Expand Down Expand Up @@ -231,9 +232,13 @@ def install_kernel(
Returns:
`tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
"""
api = _get_hf_api(user_agent=user_agent)

if not local_files_only:
repo_id, revision = resolve_status(api, repo_id, revision)

package_name = package_name_from_repo_id(repo_id)
allow_patterns = [f"build/{variant}/*" for variant in _build_variants(backend)]
api = _get_hf_api(user_agent=user_agent)
repo_path = Path(
str(
api.snapshot_download(
Expand Down Expand Up @@ -443,7 +448,6 @@ def has_kernel(
`bool`: `True` if a kernel is available for the current environment.
"""
revision = select_revision_or_version(repo_id, revision=revision, version=version)

package_name = package_name_from_repo_id(repo_id)

api = _get_hf_api()
Expand Down
93 changes: 93 additions & 0 deletions kernels/tests/test_status.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import pytest
from unittest.mock import MagicMock

from kernels.status import (
Redirect,
KernelStatus,
resolve_status,
)


class TestKernelStatusFromToml:
def test_simple_redirect(self):
content = '''kind = "redirect"
destination = "kernels-community/new-kernel"'''
result = KernelStatus.from_toml(content)
assert isinstance(result, Redirect)
assert result.destination == "kernels-community/new-kernel"
assert result.revision == "main"

def test_redirect_with_revision(self):
content = '''kind = "redirect"
destination = "kernels-community/new-kernel"
revision = "v2"'''
result = KernelStatus.from_toml(content)
assert isinstance(result, Redirect)
assert result.destination == "kernels-community/new-kernel"
assert result.revision == "v2"

def test_missing_kind_raises(self):
with pytest.raises(ValueError, match="must contain a 'kind' field"):
KernelStatus.from_toml('destination = "kernels-community/new-kernel"')

def test_unknown_kind_raises(self):
content = '''kind = "unknown"
destination = "kernels-community/new-kernel"'''
with pytest.raises(ValueError, match="Unknown kernel status kind"):
KernelStatus.from_toml(content)

def test_missing_destination_raises(self):
with pytest.raises(ValueError, match="must contain a 'destination'"):
KernelStatus.from_toml('kind = "redirect"')

def test_empty_content_raises(self):
with pytest.raises(ValueError, match="must contain a 'kind'"):
KernelStatus.from_toml("")


class TestResolveStatus:
def test_no_status(self):
from huggingface_hub.utils import EntryNotFoundError

mock_api = MagicMock()
mock_api.hf_hub_download.side_effect = EntryNotFoundError("Not found")

repo_id, revision = resolve_status(mock_api, "kernels-test/kernel", "main")
assert repo_id == "kernels-test/kernel"
assert revision == "main"

def test_redirect(self, tmp_path):
from huggingface_hub.utils import EntryNotFoundError

status_file = tmp_path / "kernel-status.toml"
status_file.write_text('kind = "redirect"\ndestination = "kernels-community/new-kernel"')

def mock_download(repo_id, filename, revision):
if repo_id == "kernels-test/old-kernel":
return str(status_file)
raise EntryNotFoundError("Not found")

mock_api = MagicMock()
mock_api.hf_hub_download.side_effect = mock_download

repo_id, revision = resolve_status(mock_api, "kernels-test/old-kernel", "main")
assert repo_id == "kernels-community/new-kernel"
assert revision == "main"

def test_redirect_with_revision(self, tmp_path):
from huggingface_hub.utils import EntryNotFoundError

status_file = tmp_path / "kernel-status.toml"
status_file.write_text('kind = "redirect"\ndestination = "kernels-community/new-kernel"\nrevision = "v2"')

def mock_download(repo_id, filename, revision):
if repo_id == "kernels-test/old-kernel":
return str(status_file)
raise EntryNotFoundError("Not found")

mock_api = MagicMock()
mock_api.hf_hub_download.side_effect = mock_download

repo_id, revision = resolve_status(mock_api, "kernels-test/old-kernel", "main")
assert repo_id == "kernels-community/new-kernel"
assert revision == "v2"
Loading