diff --git a/kernels/src/kernels/cli/upload.py b/kernels/src/kernels/cli/upload.py
index 52f37ff..c1aacf5 100644
--- a/kernels/src/kernels/cli/upload.py
+++ b/kernels/src/kernels/cli/upload.py
@@ -1,9 +1,66 @@
 from pathlib import Path
 
+from huggingface_hub import CommitOperationAdd, CommitOperationDelete
+from huggingface_hub.utils import chunk_iterable
+
 from kernels.metadata import Metadata
 from kernels.utils import _get_hf_api
 from kernels.variants import BUILD_VARIANT_REGEX
 
+BUILD_COMMIT_BATCH_SIZE = 1_000
+
+
+def _upload_build_dir(
+    api,
+    *,
+    repo_id: str,
+    revision: str | None,
+    build_dir: Path,
+    variants: list[Path],
+):
+    repo_paths = {}
+    for variant in variants:
+        for path in sorted(variant.rglob("*")):
+            if path.is_file():
+                repo_paths[f"build/{path.relative_to(build_dir).as_posix()}"] = path
+
+    variant_prefixes = tuple(
+        f"build/{variant.relative_to(build_dir).as_posix()}/" for variant in variants
+    )
+    operations = [
+        CommitOperationDelete(path_in_repo=repo_file)
+        for repo_file in sorted(
+            api.list_repo_files(repo_id=repo_id, revision=revision, repo_type="model")
+        )
+        if repo_file.startswith(variant_prefixes) and repo_file not in repo_paths
+    ]
+    operations.extend(
+        CommitOperationAdd(path_in_repo=repo_path, path_or_fileobj=str(local_path))
+        for repo_path, local_path in sorted(repo_paths.items())
+    )
+
+    batch_count = (len(operations) + BUILD_COMMIT_BATCH_SIZE - 1) // BUILD_COMMIT_BATCH_SIZE
+    if batch_count > 1:
+        print(
+            f"⚠️ Found {len(operations)} build operations, uploading in {batch_count} commits."
+        )
+
+    for batch_index, chunk in enumerate(
+        chunk_iterable(operations, chunk_size=BUILD_COMMIT_BATCH_SIZE), start=1
+    ):
+        commit_message = "Build uploaded using `kernels`."
+        if batch_count > 1:
+            commit_message = (
+                f"Build uploaded using `kernels` (batch {batch_index}/{batch_count})."
+            )
+        api.create_commit(
+            repo_id=repo_id,
+            operations=list(chunk),
+            revision=revision,
+            repo_type="model",
+            commit_message=commit_message,
+        )
+
 
 def upload_kernels_dir(
     kernel_dir: Path,
@@ -53,12 +110,7 @@ def upload_kernels_dir(
     if branch is not None:
         api.create_branch(repo_id=repo_id, branch=branch, exist_ok=True)
 
-    delete_patterns: set[str] = set()
-    for build_variant in build_dir.iterdir():
-        if build_variant.is_dir():
-            delete_patterns.add(f"{build_variant.name}/**")
-
-    # in the case we have variants, upload to the same as the kernel_dir
+    # In the case we have benchmarks, upload to the same repo as the kernel_dir.
     if (kernel_dir / "benchmarks").is_dir():
         api.upload_folder(
             repo_id=repo_id,
@@ -70,13 +122,13 @@ def upload_kernels_dir(
             allow_patterns=["benchmark*.py"],
         )
 
-    api.upload_folder(
+    assert variants is not None
+    _upload_build_dir(
+        api,
         repo_id=repo_id,
-        folder_path=build_dir,
         revision=branch,
-        path_in_repo="build",
-        delete_patterns=list(delete_patterns),
-        commit_message="Build uploaded using `kernels`.",
-        allow_patterns=["torch*"],
+        build_dir=build_dir,
+        variants=variants,
     )
+    print(f"✅ Kernel upload successful. Find the kernel in: https://hf.co/{repo_id}")
 
diff --git a/kernels/tests/test_kernel_upload.py b/kernels/tests/test_kernel_upload.py
index 11928bf..001b819 100644
--- a/kernels/tests/test_kernel_upload.py
+++ b/kernels/tests/test_kernel_upload.py
@@ -4,10 +4,14 @@
 import tempfile
 from dataclasses import dataclass
 from pathlib import Path
+from types import SimpleNamespace
+from unittest.mock import Mock
 
 import pytest
+from huggingface_hub import CommitOperationAdd, CommitOperationDelete
 
 from kernels.cli import upload_kernels
+from kernels.cli.upload import BUILD_COMMIT_BATCH_SIZE
 from kernels.utils import _get_hf_api
 
 REPO_ID = "valid_org/kernels-upload-test"
@@ -120,3 +124,64 @@
             str(filename_to_change) in k for k in repo_filenames
         ), f"{repo_filenames=}"
     _get_hf_api().delete_repo(repo_id=REPO_ID)
+
+
+def test_large_kernel_upload_uses_create_commit_batches(monkeypatch, tmp_path):
+    kernel_root = tmp_path / "kernel"
+    build_variant = kernel_root / "build" / "torch-cpu"
+    build_variant.mkdir(parents=True, exist_ok=True)
+    (build_variant / "metadata.json").write_text("{}")
+    file_count = BUILD_COMMIT_BATCH_SIZE * 2
+    for i in range(file_count):
+        (build_variant / f"file_{i}.py").touch()
+
+    api = Mock()
+    api.create_repo.return_value = SimpleNamespace(repo_id=REPO_ID)
+    api.list_repo_files.return_value = [
+        "README.md",
+        "build/torch-cpu/file_0.py",
+        "build/torch-cpu/stale.py",
+        "build/torch-cuda/keep.py",
+    ]
+    monkeypatch.setattr("kernels.cli.upload._get_hf_api", lambda: api)
+
+    upload_kernels(UploadArgs(kernel_root, REPO_ID, False, "main"))
+
+    # 2 full batches of adds, plus metadata and 1 stale-file delete.
+    assert api.create_commit.call_count == 3
+    batch_sizes = [
+        len(call.kwargs["operations"]) for call in api.create_commit.call_args_list
+    ]
+    assert batch_sizes == [
+        BUILD_COMMIT_BATCH_SIZE,
+        BUILD_COMMIT_BATCH_SIZE,
+        2,
+    ]
+    commit_messages = [
+        call.kwargs["commit_message"] for call in api.create_commit.call_args_list
+    ]
+    assert commit_messages == [
+        "Build uploaded using `kernels` (batch 1/3).",
+        "Build uploaded using `kernels` (batch 2/3).",
+        "Build uploaded using `kernels` (batch 3/3).",
+    ]
+
+    # Stale repo files should be deleted.
+    operations = [
+        operation
+        for call in api.create_commit.call_args_list
+        for operation in call.kwargs["operations"]
+    ]
+    delete_paths = {
+        op.path_in_repo for op in operations if isinstance(op, CommitOperationDelete)
+    }
+    assert delete_paths == {"build/torch-cpu/stale.py"}
+
+    add_paths = {
+        op.path_in_repo for op in operations if isinstance(op, CommitOperationAdd)
+    }
+    assert len(add_paths) == file_count + 1
+    assert "build/torch-cpu/metadata.json" in add_paths
+    assert "build/torch-cpu/file_0.py" in add_paths
+    assert "build/torch-cpu/file_399.py" in add_paths
+    api.upload_folder.assert_not_called()