33
44import logging
55from functools import wraps , lru_cache
6+ import os , sys
67from typing import Dict , List , Optional , Collection
78
89from parla import array
1516
1617logger = logging .getLogger (__name__ )
1718
19+ if 'cupy_backends' in sys .modules :
20+ # TODO: This should be dynamically configurable. That needs to be fixed upstream though.
21+ raise ImportError ("cupy must be imported after parla.gpu for per-thread default stream configuration to work properly." )
22+
23+ os .environ ["CUPY_CUDA_PER_THREAD_DEFAULT_STREAM" ] = "1"
24+ # numba responds to this env var even if it has already been imported.
25+ os .environ ["NUMBA_CUDA_PER_THREAD_DEFAULT_STREAM" ] = "1"
26+
1827try :
1928 import cupy
2029 import cupy .cuda
@@ -167,33 +176,19 @@ def get_array_module(self, a):
167176# Integration with parla.environments
168177
169178class _GPUStacksLocal (threading .local ):
170- _stream_stack : List [cupy .cuda .Stream ]
171179 _device_stack : List [cupy .cuda .Device ]
172180
173181 def __init__ (self ):
174182 super (_GPUStacksLocal , self ).__init__ ()
175- self ._stream_stack = []
176183 self ._device_stack = []
177184
178- def push_stream (self , stream ):
179- self ._stream_stack .append (stream )
180-
181- def pop_stream (self ) -> cupy .cuda .Stream :
182- return self ._stream_stack .pop ()
183-
184185 def push_device (self , dev ):
185186 self ._device_stack .append (dev )
186187
187188 def pop_device (self ) -> cupy .cuda .Device :
188189 return self ._device_stack .pop ()
189190
190191 @property
191- def stream (self ):
192- if self ._stream_stack :
193- return self ._stream_stack [- 1 ]
194- else :
195- return None
196- @property
197192 def device (self ):
198193 if self ._device_stack :
199194 return self ._device_stack [- 1 ]
@@ -213,30 +208,20 @@ def __init__(self, descriptor: "GPUComponent", env: TaskEnvironment):
213208 # Use a stack per thread per GPU component just in case.
214209 self ._stack = _GPUStacksLocal ()
215210
216- def _make_stream (self ):
217- with self .gpu .cupy_device :
218- return cupy .cuda .Stream (null = False , non_blocking = True )
219-
220211 def __enter__ (self ):
221212 _gpu_locals ._gpus = self .gpus
222213 dev = self .gpu .cupy_device
223214 dev .__enter__ ()
224215 self ._stack .push_device (dev )
225- stream = self ._make_stream ()
226- stream .__enter__ ()
227- self ._stack .push_stream (stream )
228216 return self
229217
230218 def __exit__ (self , exc_type , exc_val , exc_tb ):
231219 dev = self ._stack .device
232- stream = self ._stack .stream
233220 try :
234- stream .synchronize ()
235- stream .__exit__ (exc_type , exc_val , exc_tb )
221+ cupy .cuda .get_current_stream ().synchronize ()
236222 _gpu_locals ._gpus = None
237223 ret = dev .__exit__ (exc_type , exc_val , exc_tb )
238224 finally :
239- self ._stack .pop_stream ()
240225 self ._stack .pop_device ()
241226 return ret
242227
@@ -245,15 +230,15 @@ def initialize_thread(self) -> None:
245230 # Trigger cuBLAS/etc. initialization for this GPU in this thread.
246231 with cupy .cuda .Device (gpu .index ) as device :
247232 a = cupy .asarray ([2. ])
248- cupy .cuda .get_current_stream (). synchronize ()
249- with cupy . cuda . Stream ( False , True ) as stream :
250- cupy .asnumpy (cupy .sqrt (a ))
251- device .cublas_handle
252- device .cusolver_handle
253- device .cusolver_sp_handle
254- device .cusparse_handle
255- stream .synchronize ()
256- device .synchronize ()
233+ stream = cupy .cuda .get_current_stream ()
234+ stream . synchronize ()
235+ cupy .asnumpy (cupy .sqrt (a ))
236+ device .cublas_handle
237+ device .cusolver_handle
238+ device .cusolver_sp_handle
239+ device .cusparse_handle
240+ stream .synchronize ()
241+ device .synchronize ()
257242
258243class GPUComponent (EnvironmentComponentDescriptor ):
259244 """A single GPU CUDA component which configures the environment to use the specific GPU using a single
@@ -312,15 +297,15 @@ def initialize_thread(self) -> None:
312297 # Trigger cuBLAS/etc. initialization for this GPU in this thread.
313298 with cupy .cuda .Device (gpu .index ) as device :
314299 a = cupy .asarray ([2. ])
315- cupy .cuda .get_current_stream (). synchronize ()
316- with cupy . cuda . Stream ( False , True ) as stream :
317- cupy .asnumpy (cupy .sqrt (a ))
318- device .cublas_handle
319- device .cusolver_handle
320- device .cusolver_sp_handle
321- device .cusparse_handle
322- stream .synchronize ()
323- device .synchronize ()
300+ stream = cupy .cuda .get_current_stream ()
301+ stream . synchronize ()
302+ cupy .asnumpy (cupy .sqrt (a ))
303+ device .cublas_handle
304+ device .cusolver_handle
305+ device .cusolver_sp_handle
306+ device .cusparse_handle
307+ stream .synchronize ()
308+ device .synchronize ()
324309
325310
326311class MultiGPUComponent (EnvironmentComponentDescriptor ):
0 commit comments