-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathsetup.py
More file actions
87 lines (74 loc) · 2.42 KB
/
setup.py
File metadata and controls
87 lines (74 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright 2025 Parallel Software and Systems Group, University of Maryland.
# See the top-level LICENSE file for details.
#
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
from setuptools import setup, find_packages
import pathlib
import sys
# === Dependency checks ===
# torch and mpi4py must already be importable when this script runs: the
# imports below (torch.utils.cpp_extension, mpi4py.get_include) need them at
# build time, and they are deliberately not listed in install_requires.
missing = []
try:
    import torch
except ImportError:
    missing.append("torch")
try:
    import mpi4py
except ImportError:
    missing.append("mpi4py")
if missing:
    # Report on stderr and show the exact command, so the "install them
    # manually" hint is actionable instead of ending on a dangling colon.
    print(
        f"ERROR: Missing required packages: {', '.join(missing)}",
        file=sys.stderr,
    )
    print("Please install them manually before proceeding:", file=sys.stderr)
    print(f"    pip install {' '.join(missing)}", file=sys.stderr)
    sys.exit(1)
from torch.utils.cpp_extension import BuildExtension, CppExtension
import mpi4py
# === Optional ninja ===
# Probe for the optional ninja package. USE_NINJA records the result; when it
# is absent a yellow (ANSI-colored) warning is printed and the build proceeds
# without it.
try:
    import ninja
except ImportError:
    USE_NINJA = False
    print(
        "\033[93m[Warning] 'ninja' is not installed. "
        "Falling back to slower build system. "
        "For faster builds, run: pip install ninja\033[0m"
    )
else:
    USE_NINJA = True
# === Paths ===
# Extension sources live in csrc/ next to this file. Path.glob yields entries
# in filesystem order, so each group is sorted to make the compile/link order
# reproducible across machines; .cpp files still precede .cu files.
srcpath = pathlib.Path(__file__).parent / "csrc"
sources = sorted(str(p) for p in srcpath.glob("*.cpp")) + sorted(
    str(p) for p in srcpath.glob("*.cu")
)
# Header search paths for the extension: mpi4py's C headers plus csrc/ itself.
extra_include_paths = [mpi4py.get_include(), str(srcpath)]
# === Long description ===
# Use the README next to this file as the package's long description when it
# exists; otherwise fall back to an empty string.
readme_path = pathlib.Path(__file__).parent / "README.md"
long_description = (
    readme_path.read_text(encoding="utf-8") if readme_path.exists() else ""
)
class BuildExtensionWithNinja(BuildExtension):
    """``build_ext`` command whose ninja backend choice follows ``USE_NINJA``.

    ``USE_NINJA`` is the module-level flag set by probing ``import ninja``.
    Requesting ninja only when that probe succeeded keeps the build consistent
    with the warning printed when ninja is missing. A ``use_ninja`` keyword
    supplied explicitly by the caller takes precedence over the probe.
    """

    def __init__(self, *args, **kwargs):
        # setdefault (rather than a hard-coded use_ninja=True) avoids a
        # duplicate-keyword TypeError if a caller passes use_ninja itself.
        kwargs.setdefault("use_ninja", USE_NINJA)
        super().__init__(*args, **kwargs)
# === Setup ===
setup(
    name="pccl",
    version="0.1.0",
    packages=find_packages(where="."),
    package_dir={"pccl": "pccl"},
    ext_modules=[
        # NOTE(review): `sources` may include csrc/*.cu files, but a
        # CppExtension does not add CUDA include/library paths. Confirm
        # whether CUDAExtension is intended if any .cu sources exist.
        CppExtension(
            name="pccl_mpi_extension",
            sources=sources,
            extra_compile_args=["-O3"],
            # `include_dirs` is the setuptools.Extension option for header
            # search paths. `extra_include_paths` is only understood by
            # torch.utils.cpp_extension.load() and is ignored (with a warning)
            # here. The list already contains mpi4py.get_include(), so no
            # separate -I flag is needed. A copy is passed in case the list is
            # extended in place downstream.
            include_dirs=list(extra_include_paths),
        )
    ],
    cmdclass={"build_ext": BuildExtensionWithNinja},
    install_requires=[],  # torch and mpi4py are manually checked above
    author="Siddharth Singh, Abhinav Bhatele",
    description="Performant Collective Communication Library",
    long_description=long_description,
    long_description_content_type="text/markdown",
    # TODO(review): placeholder URL; replace with the real repository URL.
    url="https://github.com/your-org/pccl",
    classifiers=[
        "Programming Language :: Python :: 3",
        # Matches this file's SPDX header: Apache-2.0 WITH LLVM-exception.
        "License :: OSI Approved :: Apache Software License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.7",
)