Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions API.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,17 @@ import torchruntime
You can use the command line:
`python -m torchruntime install <optional list of package names and versions>`

CLI flags: `--policy <compat|stable|preview>`, `--preview`, `--no-unsupported`, `--uv`

Or you can use the library:
```py
torchruntime.install(["torch", "torchvision<0.20"])
```

Optional flags:
- `preview=True` to allow preview builds (e.g. ROCm 6.4, nightly builds, XPU test index)
- `unsupported=False` to forbid EOL/unsupported builds (e.g. Torch-DirectML, IPEX, Torch 1.x)

On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`).

## Test torch
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the approp

**Tip:** You can also add the `--uv` flag to install packages using [uv](https://docs.astral.sh/uv/) (instead of `pip`). For e.g. `python -m torchruntime install --uv`

Build-selection options:
- `--policy <name>`: `compat` (default), `stable`, `preview` (or `nightly`)
- Overrides: `--preview`, `--no-unsupported`

### Step 2. Configure torch
This should be run inside your program, to initialize the required environment variables (if any) for the variant of torch being used.

Expand Down
45 changes: 35 additions & 10 deletions tests/test_installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,26 @@ def test_cuda_platform_windows_installs_triton(monkeypatch):
def test_cuda_nightly_platform_linux(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
packages = ["torch", "torchvision"]
result = get_install_commands("nightly/cu112", packages)
result = get_install_commands("nightly/cu112", packages, preview=True)
expected_url = "https://download.pytorch.org/whl/nightly/cu112"
assert result == [packages + ["--index-url", expected_url]]


def test_cuda_nightly_platform_windows_installs_triton(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
packages = ["torch", "torchvision"]
result = get_install_commands("nightly/cu112", packages)
result = get_install_commands("nightly/cu112", packages, preview=True)
expected_url = "https://download.pytorch.org/whl/nightly/cu112"
assert result == [packages + ["--index-url", expected_url], ["triton-windows"]]


def test_cuda_nightly_platform_requires_preview(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
packages = ["torch", "torchvision"]
with pytest.raises(ValueError, match="preview"):
get_install_commands("nightly/cu112", packages, preview=False)


def test_rocm_4_platform_does_not_install_triton(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
packages = ["torch", "torchvision"]
Expand All @@ -71,25 +78,43 @@ def test_rocm_6_platform_linux_installs_triton(monkeypatch):
def test_xpu_platform_windows_with_torch_only(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
packages = ["torch"]
result = get_install_commands("xpu", packages)
expected_url = "https://download.pytorch.org/whl/test/xpu"
result = get_install_commands("xpu", packages, preview=False)
expected_url = "https://download.pytorch.org/whl/xpu"
assert result == [packages + ["--index-url", expected_url]]


def test_xpu_platform_windows_with_torchvision(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
packages = ["torch", "torchvision"]
result = get_install_commands("xpu", packages, preview=False)
expected_url = "https://download.pytorch.org/whl/xpu"
assert result == [packages + ["--index-url", expected_url]]


def test_xpu_platform_windows_with_torchvision(monkeypatch, capsys):
def test_xpu_platform_windows_preview(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
packages = ["torch", "torchvision"]
result = get_install_commands("xpu", packages)
expected_url = "https://download.pytorch.org/whl/nightly/xpu"
result = get_install_commands("xpu", packages, preview=True)
expected_url = "https://download.pytorch.org/whl/test/xpu"
assert result == [packages + ["--index-url", expected_url]]
captured = capsys.readouterr()
assert "[WARNING]" in captured.out


def test_xpu_platform_linux(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
packages = ["torch", "torchvision"]
result = get_install_commands("xpu", packages)
result = get_install_commands("xpu", packages, preview=False)
expected_url = "https://download.pytorch.org/whl/xpu"
triton_index_url = "https://download.pytorch.org/whl"
assert result == [
packages + ["--index-url", expected_url],
["pytorch-triton-xpu", "--index-url", triton_index_url],
]


def test_xpu_platform_linux_preview(monkeypatch):
monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
packages = ["torch", "torchvision"]
result = get_install_commands("xpu", packages, preview=True)
expected_url = "https://download.pytorch.org/whl/test/xpu"
triton_index_url = "https://download.pytorch.org/whl"
assert result == [
Expand Down
11 changes: 10 additions & 1 deletion tests/test_platform_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def test_amd_gpu_navi4_linux(monkeypatch):
with pytest.raises(NotImplementedError):
get_torch_platform(gpu_infos)
else:
assert get_torch_platform(gpu_infos) == "rocm6.4"
assert get_torch_platform(gpu_infos) == "cpu"
assert get_torch_platform(gpu_infos, preview=True) == "rocm6.4"


def test_amd_gpu_navi3_linux(monkeypatch, capsys):
Expand Down Expand Up @@ -89,6 +90,14 @@ def test_amd_gpu_ellesmere_linux(monkeypatch):
assert get_torch_platform(gpu_infos) == "rocm4.2"


def test_amd_gpu_ellesmere_linux_unsupported_false_raises(monkeypatch):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
gpu_infos = [GPU(AMD, "AMD", 0x1234, "Ellesmere", True)]
with pytest.raises(ValueError, match="End-of-Life"):
get_torch_platform(gpu_infos, unsupported=False)


def test_amd_gpu_unsupported_linux(monkeypatch, capsys):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
Expand Down
81 changes: 81 additions & 0 deletions tests/test_policy_parsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import pytest
from torchruntime.utils.args import parse_policy_args


def test_default_policy():
args = ["pkg1"]
preview, unsupported, cleaned = parse_policy_args(args)
assert preview is False
assert unsupported is True
assert cleaned == ["pkg1"]


def test_stable_policy():
args = ["--policy", "stable", "pkg1"]
preview, unsupported, cleaned = parse_policy_args(args)
assert preview is False
assert unsupported is False
assert cleaned == ["pkg1"]


def test_nightly_policy():
args = ["--policy", "nightly"]
preview, unsupported, cleaned = parse_policy_args(args)
assert preview is True
assert unsupported is True
assert cleaned == []


def test_preview_policy_alias():
args = ["--policy", "preview"]
preview, unsupported, cleaned = parse_policy_args(args)
assert preview is True
assert unsupported is True
assert cleaned == []


def test_policy_equals_syntax():
args = ["--policy=stable", "pkg1"]
preview, unsupported, cleaned = parse_policy_args(args)
assert preview is False
assert unsupported is False
assert cleaned == ["pkg1"]


def test_policy_override_preview():
# stable is p=F, u=F. --preview should make p=T
args = ["--policy", "stable", "--preview"]
preview, unsupported, cleaned = parse_policy_args(args)
assert preview is True
assert unsupported is False

def test_policy_override_unsupported():
# nightly is p=T, u=T. --no-unsupported should make u=F
args = ["--policy", "nightly", "--no-unsupported"]
preview, unsupported, cleaned = parse_policy_args(args)
assert preview is True
assert unsupported is False


def test_unknown_policy():
args = ["--policy", "nonexistent"]
with pytest.raises(ValueError, match="Unknown policy"):
parse_policy_args(args)


def test_missing_policy_arg():
args = ["--policy"]
with pytest.raises(ValueError, match="--policy requires an argument"):
parse_policy_args(args)


def test_mixed_args():
args = ["torch", "--preview", "--policy", "stable", "--uv"]
# stable: p=F, u=F
# --preview: p=T
# Result: p=T, u=F
# cleaned: ["torch", "--uv"]
preview, unsupported, cleaned = parse_policy_args(args)
assert preview is True
assert unsupported is False
assert cleaned == ["torch", "--uv"]
68 changes: 68 additions & 0 deletions tests/test_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pytest
from torchruntime.device_db import GPU
from torchruntime.platform_detection import AMD, INTEL, get_torch_platform, py_version


def test_preview_rocm_6_4_selection(monkeypatch):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
gpu_infos = [GPU(AMD, "AMD", 0x1234, "Navi 41", True)]

if py_version < (3, 9):
pytest.skip("Navi 4 requires Python 3.9+")

# Default: preview=False -> cpu
assert get_torch_platform(gpu_infos) == "cpu"
assert get_torch_platform(gpu_infos, preview=False) == "cpu"

# preview=True -> rocm6.4
assert get_torch_platform(gpu_infos, preview=True) == "rocm6.4"


def test_eol_rocm_5_2_selection(monkeypatch):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
gpu_infos = [GPU(AMD, "AMD", 0x1234, "Navi 10", True)]

assert get_torch_platform(gpu_infos) == "rocm5.2"

with pytest.raises(ValueError, match="considered End-of-Life"):
get_torch_platform(gpu_infos, unsupported=False)


def test_eol_rocm42_selection(monkeypatch):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
# Ellesmere (e.g. RX 580)
gpu_infos = [GPU(AMD, "AMD", "67df", "Ellesmere [Radeon RX 470/480/570/570X/580/580X/590]", True)]

# Default: unsupported=True -> rocm4.2
assert get_torch_platform(gpu_infos) == "rocm4.2"

# unsupported=False -> raises ValueError
with pytest.raises(ValueError, match="considered End-of-Life"):
get_torch_platform(gpu_infos, unsupported=False)


def test_eol_directml_selection(monkeypatch):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows")
monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64")
gpu_infos = [GPU(AMD, "AMD", 0x1234, "Radeon", True)]

assert get_torch_platform(gpu_infos) == "directml"

with pytest.raises(ValueError, match="considered End-of-Life"):
get_torch_platform(gpu_infos, unsupported=False)


def test_eol_ipex_selection(monkeypatch):
monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux")
monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64")
monkeypatch.setattr("torchruntime.platform_detection.py_version", (3, 8))
gpu_infos = [GPU(INTEL, "Intel", 0x1234, "Iris", True)]

assert get_torch_platform(gpu_infos) == "ipex"

# unsupported=False -> raises ValueError
with pytest.raises(ValueError, match="considered End-of-Life"):
get_torch_platform(gpu_infos, unsupported=False)
30 changes: 25 additions & 5 deletions torchruntime/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .installer import install
from .utils.torch_test import test
from .utils import info
from .utils.args import parse_policy_args


def print_usage(entry_command: str):
Expand All @@ -16,6 +17,9 @@ def print_usage(entry_command: str):
Examples:
{entry_command} install
{entry_command} install --uv
{entry_command} install --preview
{entry_command} install --no-unsupported
{entry_command} install --policy stable
{entry_command} install torch==2.2.0 torchvision==0.17.0
{entry_command} install --uv torch>=2.0.0 torchaudio
{entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0
Expand All @@ -35,6 +39,9 @@ def print_usage(entry_command: str):

Options:
--uv Use uv instead of pip for installation
--preview Allow preview builds (e.g. ROCm 6.4)
--no-unsupported Forbid EOL/unsupported builds (e.g. DirectML / IPEX / Torch 1.x)
--policy <name> Set configuration policy (stable, compat, preview|nightly). Default: compat

Version specification formats (follows pip format):
package==2.1.0 Exact version
Expand Down Expand Up @@ -62,15 +69,28 @@ def main():

if command == "install":
args = sys.argv[2:] if len(sys.argv) > 2 else []
use_uv = "--uv" in args
# Remove --uv from args to get package list
package_versions = [arg for arg in args if arg != "--uv"] if args else None
install(package_versions, use_uv=use_uv)
try:
preview, unsupported, cleaned_args = parse_policy_args(args)
except ValueError as e:
print(f"Error: {e}")
return

use_uv = "--uv" in cleaned_args
# Remove --uv from package list
package_versions = [arg for arg in cleaned_args if arg != "--uv"]
install(package_versions, use_uv=use_uv, preview=preview, unsupported=unsupported)
elif command == "test":
subcommand = sys.argv[2] if len(sys.argv) > 2 else "all"
test(subcommand)
elif command == "info":
info()
args = sys.argv[2:] if len(sys.argv) > 2 else []
try:
preview, unsupported, _ = parse_policy_args(args)
except ValueError as e:
print(f"Error: {e}")
return
from .utils import info
info(preview=preview, unsupported=unsupported)
else:
print(f"Unknown command: {command}")
entry_path = sys.argv[0]
Expand Down
7 changes: 7 additions & 0 deletions torchruntime/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,10 @@
AMD = "1002"
NVIDIA = "10de"
INTEL = "8086"

POLICIES = {
"stable": (False, False), # preview=False, unsupported=False
"compat": (False, True), # preview=False, unsupported=True (Default)
"preview": (True, True), # preview=True, unsupported=True
"nightly": (True, True), # alias for preview
}
Loading