-
Notifications
You must be signed in to change notification settings - Fork 5
Fix cu128 index selection for pinned Torch versions #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,6 +13,131 @@ | |
| CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$") | ||
| ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$") | ||
|
|
||
| _CUDA_12_8_PLATFORM = "cu128" | ||
| _CUDA_12_4_PLATFORM = "cu124" | ||
| _CUDA_12_8_MIN_VERSIONS = { | ||
| "torch": (2, 7, 0), | ||
| "torchaudio": (2, 7, 0), | ||
| "torchvision": (0, 22, 0), | ||
| } | ||
|
|
||
|
|
||
| def _parse_version_segments(text): | ||
| text = text.strip().split("+", 1)[0] | ||
| segments = [] | ||
| for part in text.split("."): | ||
| m = re.match(r"^(\d+)", part) | ||
| if not m: | ||
| break | ||
| segments.append(int(m.group(1))) | ||
| return segments | ||
|
|
||
|
|
||
| def _as_version_tuple(version_segments): | ||
| padded = list(version_segments[:3]) | ||
| while len(padded) < 3: | ||
| padded.append(0) | ||
| return tuple(padded) | ||
|
|
||
|
|
||
| def _version_lt(a, b): | ||
| return _as_version_tuple(a) < _as_version_tuple(b) | ||
|
|
||
|
|
||
| def _version_le(a, b): | ||
| return _as_version_tuple(a) <= _as_version_tuple(b) | ||
|
|
||
|
|
||
| def _get_requirement_name_and_specifier(requirement): | ||
| req = requirement.strip() | ||
| if not req or req.startswith("-") or "@" in req: | ||
| return None, None | ||
|
|
||
| match = re.match(r"^([A-Za-z0-9][A-Za-z0-9_.-]*)(?:\[[^\]]+\])?", req) | ||
| if not match: | ||
| return None, None | ||
|
|
||
| name = match.group(1).lower().replace("_", "-") | ||
| spec = req[match.end() :].split(";", 1)[0].strip() | ||
| return name, spec | ||
|
|
||
|
|
||
| def _upper_bound_for_specifier(specifier): | ||
| """ | ||
| Returns (upper_bound_segments, is_inclusive) for specifiers that impose an upper bound, | ||
| or (None, None) if there is no upper bound. | ||
| """ | ||
|
|
||
| s = specifier.strip() | ||
|
|
||
| if s.startswith("=="): | ||
| value = s[2:].strip() | ||
| if "*" in value: | ||
| prefix = value.split("*", 1)[0].rstrip(".") | ||
| prefix_segments = _parse_version_segments(prefix) | ||
| if not prefix_segments: | ||
| return None, None | ||
| upper = list(prefix_segments) | ||
| upper[-1] += 1 | ||
| upper.append(0) | ||
| return upper, False | ||
|
|
||
| return _parse_version_segments(value), True | ||
|
|
||
| if s.startswith("<="): | ||
| return _parse_version_segments(s[2:].strip()), True | ||
|
|
||
| if s.startswith("<"): | ||
| return _parse_version_segments(s[1:].strip()), False | ||
|
|
||
| if s.startswith("~="): | ||
| value_segments = _parse_version_segments(s[2:].strip()) | ||
| if len(value_segments) < 2: | ||
| return None, None | ||
| upper = list(value_segments[:-1]) | ||
| upper[-1] += 1 | ||
| upper.append(0) | ||
| return upper, False | ||
|
|
||
| return None, None | ||
|
|
||
|
|
||
| def _packages_require_cuda_12_4(packages): | ||
| """ | ||
| True if the requested torch package versions cannot be satisfied by the CUDA 12.8 wheel index. | ||
|
|
||
| This happens when a package is pinned (or capped) below the first version that has CUDA 12.8 wheels. | ||
| """ | ||
|
|
||
| if not packages: | ||
| return False | ||
|
|
||
| for package in packages: | ||
| name, spec = _get_requirement_name_and_specifier(package) | ||
| if not name or name not in _CUDA_12_8_MIN_VERSIONS or not spec: | ||
| continue | ||
|
|
||
| threshold = _CUDA_12_8_MIN_VERSIONS[name] | ||
| for raw in spec.split(","): | ||
| upper, inclusive = _upper_bound_for_specifier(raw) | ||
| if not upper: | ||
| continue | ||
|
|
||
| if inclusive: | ||
| if _version_lt(upper, threshold): | ||
| return True | ||
| else: | ||
| if _version_le(upper, threshold): | ||
| return True | ||
|
|
||
| return False | ||
|
|
||
|
|
||
| def _adjust_cuda_platform_for_requested_packages(torch_platform, packages): | ||
| if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages): | ||
| return _CUDA_12_4_PLATFORM | ||
| return torch_platform | ||
|
|
||
|
|
||
| def get_install_commands(torch_platform, packages): | ||
| """ | ||
|
|
@@ -99,6 +224,7 @@ def install(packages=[], use_uv=False): | |
|
|
||
| gpu_infos = get_gpus() | ||
| torch_platform = get_torch_platform(gpu_infos) | ||
| torch_platform = _adjust_cuda_platform_for_requested_packages(torch_platform, packages) | ||
|
||
| cmds = get_install_commands(torch_platform, packages) | ||
| cmds = get_pip_commands(cmds, use_uv=use_uv) | ||
| run_commands(cmds) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we really need to write a complete version parser ourselves? Isn't there any built-in library inside python that can do this?
This feels like a lot of lines of code just for this purpose.