@@ -182,7 +182,7 @@ def benchmarks(self) -> list[Benchmark]:
                         GraphApiSinKernelGraph(self, runtime, with_graphs, num_kernels)
                     )

-        # Add ULLS benchmarks
+        # Add ULLS benchmarks
         for runtime in list(RUNTIMES):
             if runtime == RUNTIMES.SYCL:
                 benches.append(
@@ -355,6 +355,39 @@ def createTorchMultiQueueBench(variant_name: str, **kwargs):
                 ),
             ]

+        # Add TorchSlmSize benchmarks
+        for runtime in filter(lambda x: x != RUNTIMES.UR, RUNTIMES):
+
+            def createTorchSlmSizeBench(variant_name: str, **kwargs):
+                return TorchSlmSize(
+                    self,
+                    runtime,
+                    variant_name,
+                    PROFILERS.TIMER,
+                    **kwargs,
+                )
+
+            benches += [
+                createTorchSlmSizeBench(
+                    "small",
+                    batchSize=512,
+                    slmNum=1,
+                    warmupIterations=1,
+                ),
+                createTorchSlmSizeBench(
+                    "medium",
+                    batchSize=512,
+                    slmNum=1024,
+                    warmupIterations=1,
+                ),
+                createTorchSlmSizeBench(
+                    "max",
+                    batchSize=512,
+                    slmNum=-1,
+                    warmupIterations=1,
+                ),
+            ]
+
         # Add UR-specific benchmarks
         benches += [
             # TODO: multithread_benchmark_ur fails with segfault
@@ -810,25 +843,31 @@ def _bin_args(self, run_trace: TracingType = TracingType.NONE) -> list[str]:
         return [f"--{k}={v}" for k, v in self._rr_params.items()]


-class TorchMultiQueue(ComputeBenchmark):
+class TorchBenchmark(ComputeBenchmark):
     def __init__(
-        self, suite, runtime: RUNTIMES, variant_name: str, profiler_type, **kwargs
+        self,
+        suite,
+        runtime: RUNTIMES,
+        bench_name: str,
+        variant_name: str,
+        profiler_type,
+        **kwargs,
     ):
         self._variant_name = variant_name
-        self._smq_params = kwargs
+        self._torch_params = kwargs
         self._iterations_regular = 1000
         self._iterations_trace = 10
         super().__init__(
             suite,
             f"torch_benchmark_{runtime.value}",
-            "KernelSubmitMultiQueue",
+            bench_name,
             runtime,
             profiler_type,
         )

     def name(self):
         ret = []
-        for k, v in self._smq_params.items():
+        for k, v in self._torch_params.items():
             ret.append(f"{k} {v}")
         ret.sort()
         return self._bench_name + " " + ", ".join(ret)
@@ -848,10 +887,38 @@ def _supported_runtimes(self) -> list[RUNTIMES]:
     def _bin_args(self, run_trace: TracingType = TracingType.NONE) -> list[str]:
         iters = self._get_iters(run_trace)
         return [f"--iterations={iters}"] + [
-            f"--{k}={v}" for k, v in self._smq_params.items()
+            f"--{k}={v}" for k, v in self._torch_params.items()
         ]


+class TorchMultiQueue(TorchBenchmark):
+    def __init__(
+        self, suite, runtime: RUNTIMES, variant_name: str, profiler_type, **kwargs
+    ):
+        super().__init__(
+            suite,
+            runtime,
+            "KernelSubmitMultiQueue",
+            variant_name,
+            profiler_type,
+            **kwargs,
+        )
+
+
+class TorchSlmSize(TorchBenchmark):
+    def __init__(
+        self, suite, runtime: RUNTIMES, variant_name: str, profiler_type, **kwargs
+    ):
+        super().__init__(
+            suite,
+            runtime,
+            "KernelSubmitSlmSize",
+            variant_name,
+            profiler_type,
+            **kwargs,
+        )
+
+
 class QueueInOrderMemcpy(ComputeBenchmark):
     def __init__(self, bench, isCopyOnly, source, destination, size, profiler_type):
         self._is_copy_only = isCopyOnly
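As the diff shows, every keyword passed to createTorchSlmSizeBench flows through TorchBenchmark's generic kwargs handling: it appears both in the benchmark's display name (name()) and as a --key=value binary argument (_bin_args()). A minimal standalone sketch of that mapping follows; the torch_bench_cli helper is hypothetical and only illustrates the logic, it is not part of the suite.

```python
# Hypothetical sketch of the kwargs-to-name/flags mapping in TorchBenchmark
# (not part of the benchmark suite itself).
def torch_bench_cli(bench_name: str, iterations: int, **params) -> tuple[str, list[str]]:
    # name(): sorted "key value" pairs appended to the benchmark name
    name = bench_name + " " + ", ".join(sorted(f"{k} {v}" for k, v in params.items()))
    # _bin_args(): --iterations plus one --key=value flag per kwarg
    args = [f"--iterations={iterations}"] + [f"--{k}={v}" for k, v in params.items()]
    return name, args


# The "max" TorchSlmSize variant registered above (slmNum=-1, presumably the maximum SLM size)
name, args = torch_bench_cli(
    "KernelSubmitSlmSize", 1000, batchSize=512, slmNum=-1, warmupIterations=1
)
print(name)  # KernelSubmitSlmSize batchSize 512, slmNum -1, warmupIterations 1
print(args)  # ['--iterations=1000', '--batchSize=512', '--slmNum=-1', '--warmupIterations=1']
```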