forked from nv-tlabs/vipe
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy path setup.py
More file actions
56 lines (44 loc) · 1.79 KB
/
setup.py
File metadata and controls
56 lines (44 loc) · 1.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import re
from setuptools import find_packages, setup
# Import PyTorch and its C++/CUDA extension helpers; nothing below can work
# without them. Only the imports live inside the ``try`` so that unrelated
# errors are not misreported as "PyTorch not found".
try:
    import torch
    import torch.version
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension
except ImportError as err:
    raise ValueError("Pytorch not found, please install it first.") from err

# Major/minor of the installed PyTorch, e.g. "2.3.0+cu121" -> ["2", "3"].
torch_version = torch.version.__version__.split(".")[:2]
# CUDA version PyTorch was built against, or None for a CPU-only build.
cuda_version = torch.version.cuda
# Explicit check instead of ``assert``: ``python -O`` strips asserts, which
# would let a CPU-only build fail below with an opaque AttributeError.
if cuda_version is None:
    raise RuntimeError("Pytorch CUDA is required for this installation.")
# Local-version tag appended to the package version, e.g. "+pt23cu121".
version_suffix = f"+pt{torch_version[0]}{torch_version[1]}cu{cuda_version.replace('.', '')}"
PACKAGE_NAME = "vipe"

# Avoid directly importing the package: parse ``__version__`` out of its
# ``__init__.py`` textually instead.
_version_file = f"{PACKAGE_NAME}/__init__.py"
with open(_version_file, "r", encoding="utf-8") as fh:
    # Tolerates optional whitespace around "=" and either quote style.
    _version_match = re.search(r"__version__\s*=\s*[\"']([^\"']*)[\"']", fh.read())
if _version_match is None:
    # Clear message instead of the bare IndexError from ``findall(...)[0]``.
    raise RuntimeError(f"Could not find a __version__ assignment in {_version_file}")
__version__ = _version_match.group(1)
# Append the PyTorch/CUDA build tag computed earlier, e.g. "1.2.3" -> "1.2.3+pt23cu121".
__version__ += version_suffix
# Load the extension build helpers from the package's ext/specs.py by executing
# the file directly, without importing the package itself. ``__file__`` is
# injected because the script runs outside a normal import; presumably specs.py
# uses it to resolve source paths relative to itself (verify in specs.py).
code_finder_path = f"{PACKAGE_NAME}/ext/specs.py"
code_finder_namespace = {"__file__": code_finder_path}
with open(code_finder_path, "r", encoding="utf-8") as fh:
    # compile() with the real filename so errors raised inside specs.py produce
    # a traceback that points at that file rather than at "<string>".
    exec(compile(fh.read(), code_finder_path, "exec"), code_finder_namespace)

# Helpers that specs.py must define; a missing one raises KeyError here.
get_sources = code_finder_namespace["get_sources"]
get_cpp_flags = code_finder_namespace["get_cpp_flags"]
get_cuda_flags = code_finder_namespace["get_cuda_flags"]
# Inside a conda environment, point PYTORCH_NVCC at the environment's own nvcc
# (when one is installed there) so the build uses a consistent compiler.
# Note: this sets PYTORCH_NVCC, not CUDA_HOME; torch.utils.cpp_extension's
# ninja build presumably honors PYTORCH_NVCC as an nvcc override — verify
# against the installed torch version.
_conda_prefix = os.environ.get("CONDA_PREFIX")
# An empty CONDA_PREFIX is ignored rather than probing a cwd-relative "bin/nvcc".
if _conda_prefix:
    conda_nvcc_path = os.path.join(_conda_prefix, "bin", "nvcc")
    if os.path.exists(conda_nvcc_path):
        os.environ["PYTORCH_NVCC"] = conda_nvcc_path
# Every Python package discovered under the project root.
packages = find_packages()

# The single compiled extension, named f"{PACKAGE_NAME}_ext". Its sources and
# per-compiler flags come from the helpers loaded from ext/specs.py.
_cuda_extension = CUDAExtension(
    f"{PACKAGE_NAME}_ext",
    sources=get_sources(),  # type: ignore
    extra_compile_args={
        "cxx": get_cpp_flags(),  # type: ignore
        "nvcc": get_cuda_flags(),  # type: ignore
    },
)

# Build with ninja via PyTorch's BuildExtension.
_build_ext_command = BuildExtension.with_options(use_ninja=True)

# NOTE(review): no `name=` is passed here; presumably the remaining metadata
# lives in pyproject.toml / setup.cfg — confirm.
setup(
    packages=packages,
    version=__version__,
    ext_modules=[_cuda_extension],
    cmdclass={"build_ext": _build_ext_command},
)