Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion tests/test_installer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sys
import pytest
from unittest.mock import patch
from torchruntime.installer import get_install_commands, get_pip_commands, run_commands
from torchruntime.installer import get_install_commands, get_pip_commands, run_commands, install


def test_empty_args():
Expand Down Expand Up @@ -125,3 +125,25 @@ def test_run_commands():
# Check that subprocess.run was called with the correct arguments
mock_run.assert_any_call(cmds[0])
mock_run.assert_any_call(cmds[1])


def test_install_demotes_cu128_to_cu124_for_torch_2_6(monkeypatch):
# Simulate a system where the detected platform would be cu128.
monkeypatch.setattr("torchruntime.installer.get_gpus", lambda: ["dummy_gpu"])
monkeypatch.setattr("torchruntime.installer.get_torch_platform", lambda gpu_infos: "cu128")

seen = {}

def fake_get_install_commands(torch_platform, packages):
seen["torch_platform"] = torch_platform
seen["packages"] = packages
return [packages]

monkeypatch.setattr("torchruntime.installer.get_install_commands", fake_get_install_commands)
monkeypatch.setattr("torchruntime.installer.get_pip_commands", lambda cmds, use_uv=False: cmds)
monkeypatch.setattr("torchruntime.installer.run_commands", lambda cmds: None)

install(["torch==2.6.0"])

assert seen["packages"] == ["torch==2.6.0"]
assert seen["torch_platform"] == "cu124"
126 changes: 126 additions & 0 deletions torchruntime/installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,131 @@
CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$")
ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$")

_CUDA_12_8_PLATFORM = "cu128"
_CUDA_12_4_PLATFORM = "cu124"
_CUDA_12_8_MIN_VERSIONS = {
"torch": (2, 7, 0),
"torchaudio": (2, 7, 0),
"torchvision": (0, 22, 0),
}


def _parse_version_segments(text):
text = text.strip().split("+", 1)[0]
segments = []
for part in text.split("."):
m = re.match(r"^(\d+)", part)
if not m:
break
segments.append(int(m.group(1)))
return segments


def _as_version_tuple(version_segments):
padded = list(version_segments[:3])
while len(padded) < 3:
padded.append(0)
return tuple(padded)


def _version_lt(a, b):
return _as_version_tuple(a) < _as_version_tuple(b)


def _version_le(a, b):
return _as_version_tuple(a) <= _as_version_tuple(b)


def _get_requirement_name_and_specifier(requirement):
req = requirement.strip()
if not req or req.startswith("-") or "@" in req:
return None, None

match = re.match(r"^([A-Za-z0-9][A-Za-z0-9_.-]*)(?:\[[^\]]+\])?", req)
if not match:
return None, None

name = match.group(1).lower().replace("_", "-")
spec = req[match.end() :].split(";", 1)[0].strip()
return name, spec


def _upper_bound_for_specifier(specifier):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really need to write a complete version parser ourselves? Isn't there any built-in library inside python that can do this?

This feels like a lot of lines of code just for this purpose.

"""
Returns (upper_bound_segments, is_inclusive) for specifiers that impose an upper bound,
or (None, None) if there is no upper bound.
"""

s = specifier.strip()

if s.startswith("=="):
value = s[2:].strip()
if "*" in value:
prefix = value.split("*", 1)[0].rstrip(".")
prefix_segments = _parse_version_segments(prefix)
if not prefix_segments:
return None, None
upper = list(prefix_segments)
upper[-1] += 1
upper.append(0)
return upper, False

return _parse_version_segments(value), True

if s.startswith("<="):
return _parse_version_segments(s[2:].strip()), True

if s.startswith("<"):
return _parse_version_segments(s[1:].strip()), False

if s.startswith("~="):
value_segments = _parse_version_segments(s[2:].strip())
if len(value_segments) < 2:
return None, None
upper = list(value_segments[:-1])
upper[-1] += 1
upper.append(0)
return upper, False

return None, None


def _packages_require_cuda_12_4(packages):
"""
True if the requested torch package versions cannot be satisfied by the CUDA 12.8 wheel index.

This happens when a package is pinned (or capped) below the first version that has CUDA 12.8 wheels.
"""

if not packages:
return False

for package in packages:
name, spec = _get_requirement_name_and_specifier(package)
if not name or name not in _CUDA_12_8_MIN_VERSIONS or not spec:
continue

threshold = _CUDA_12_8_MIN_VERSIONS[name]
for raw in spec.split(","):
upper, inclusive = _upper_bound_for_specifier(raw)
if not upper:
continue

if inclusive:
if _version_lt(upper, threshold):
return True
else:
if _version_le(upper, threshold):
return True

return False


def _adjust_cuda_platform_for_requested_packages(torch_platform, packages):
if torch_platform == _CUDA_12_8_PLATFORM and _packages_require_cuda_12_4(packages):
return _CUDA_12_4_PLATFORM
return torch_platform


def get_install_commands(torch_platform, packages):
"""
Expand Down Expand Up @@ -99,6 +224,7 @@ def install(packages=[], use_uv=False):

gpu_infos = get_gpus()
torch_platform = get_torch_platform(gpu_infos)
torch_platform = _adjust_cuda_platform_for_requested_packages(torch_platform, packages)
Copy link
Contributor

@cmdr2 cmdr2 Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please move this logic into platform_detection.py, and add an optional packages=[] arg in get_torch_platform()?

That would preserve the separation of concerns, since we're effectively fixing a problem with platform detection, not installation commands.

The unit tests would change accordingly.

cmds = get_install_commands(torch_platform, packages)
cmds = get_pip_commands(cmds, use_uv=use_uv)
run_commands(cmds)