Skip to content

Commit 16d85ea

Browse files
Better handle torch being imported by prestartup nodes. (#11383)
1 parent 5d9ad0c commit 16d85ea

File tree

1 file changed

+32
-34
lines changed

1 file changed

+32
-34
lines changed

main.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,38 @@
2323

2424
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
2525

26+
if os.name == "nt":
27+
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
28+
29+
if __name__ == "__main__":
30+
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
31+
if args.default_device is not None:
32+
default_dev = args.default_device
33+
devices = list(range(32))
34+
devices.remove(default_dev)
35+
devices.insert(0, default_dev)
36+
devices = ','.join(map(str, devices))
37+
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
38+
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
39+
40+
if args.cuda_device is not None:
41+
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
42+
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
43+
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
44+
logging.info("Set cuda device to: {}".format(args.cuda_device))
45+
46+
if args.oneapi_device_selector is not None:
47+
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
48+
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
49+
50+
if args.deterministic:
51+
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
52+
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
53+
54+
import cuda_malloc
55+
if "rocm" in cuda_malloc.get_torch_version_noimport():
56+
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
57+
2658

2759
def handle_comfyui_manager_unavailable():
2860
if not args.windows_standalone_build:
@@ -137,40 +169,6 @@ def execute_script(script_path):
137169
import threading
138170
import gc
139171

140-
141-
if os.name == "nt":
142-
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
143-
144-
if __name__ == "__main__":
145-
os.environ['TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL'] = '1'
146-
if args.default_device is not None:
147-
default_dev = args.default_device
148-
devices = list(range(32))
149-
devices.remove(default_dev)
150-
devices.insert(0, default_dev)
151-
devices = ','.join(map(str, devices))
152-
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices)
153-
os.environ['HIP_VISIBLE_DEVICES'] = str(devices)
154-
155-
if args.cuda_device is not None:
156-
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
157-
os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
158-
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = str(args.cuda_device)
159-
logging.info("Set cuda device to: {}".format(args.cuda_device))
160-
161-
if args.oneapi_device_selector is not None:
162-
os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
163-
logging.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
164-
165-
if args.deterministic:
166-
if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
167-
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
168-
169-
import cuda_malloc
170-
if "rocm" in cuda_malloc.get_torch_version_noimport():
171-
os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD
172-
173-
174172
if 'torch' in sys.modules:
175173
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
176174

0 commit comments

Comments
 (0)