@@ -182,7 +182,7 @@ def benchmarks(self) -> list[Benchmark]:
                         GraphApiSinKernelGraph(self, runtime, with_graphs, num_kernels)
                     )
 
-        # Add ULLS benchmarks
+        # Add ULLS benchmarks
         for runtime in list(RUNTIMES):
             if runtime == RUNTIMES.SYCL:
                 benches.append(
@@ -355,6 +355,46 @@ def createTorchMultiQueueBench(variant_name: str, **kwargs):
                 ),
             ]
 
+        # Add LinearKernelSize benchmarks
+        for runtime in filter(lambda x: x != RUNTIMES.UR, RUNTIMES):
+
+            def createLinearKernelSizeBench(variant_name: str, **kwargs):
+                return TorchLinearKernelSize(
+                    self,
+                    runtime,
+                    variant_name,
+                    PROFILERS.TIMER,
+                    **kwargs,
+                )
+
+            benches += [
+                createLinearKernelSizeBench(
+                    "array32",
+                    kernelBatchSize=512,
+                    kernelSize=32,
+                ),
+                createLinearKernelSizeBench(
+                    "array128",
+                    kernelBatchSize=512,
+                    kernelSize=128,
+                ),
+                createLinearKernelSizeBench(
+                    "array512",
+                    kernelBatchSize=512,
+                    kernelSize=512,
+                ),
+                createLinearKernelSizeBench(
+                    "array1024",
+                    kernelBatchSize=512,
+                    kernelSize=1024,
+                ),
+                createLinearKernelSizeBench(
+                    "array5120",
+                    kernelBatchSize=512,
+                    kernelSize=5120,
+                ),
+            ]
+
         # Add UR-specific benchmarks
         benches += [
             # TODO: multithread_benchmark_ur fails with segfault
@@ -810,25 +850,31 @@ def _bin_args(self, run_trace: TracingType = TracingType.NONE) -> list[str]:
         return [f"--{k}={v}" for k, v in self._rr_params.items()]
 
 
-class TorchMultiQueue(ComputeBenchmark):
+class TorchBenchmark(ComputeBenchmark):
     def __init__(
-        self, suite, runtime: RUNTIMES, variant_name: str, profiler_type, **kwargs
+        self,
+        suite,
+        runtime: RUNTIMES,
+        bench_name: str,
+        variant_name: str,
+        profiler_type,
+        **kwargs,
     ):
         self._variant_name = variant_name
-        self._smq_params = kwargs
+        self._torch_params = kwargs
         self._iterations_regular = 1000
         self._iterations_trace = 10
         super().__init__(
             suite,
             f"torch_benchmark_{runtime.value}",
-            "KernelSubmitMultiQueue",
+            bench_name,
             runtime,
             profiler_type,
         )
 
     def name(self):
         ret = []
-        for k, v in self._smq_params.items():
+        for k, v in self._torch_params.items():
             ret.append(f"{k} {v}")
         ret.sort()
         return self._bench_name + " " + ", ".join(ret)
@@ -848,10 +894,38 @@ def _supported_runtimes(self) -> list[RUNTIMES]:
     def _bin_args(self, run_trace: TracingType = TracingType.NONE) -> list[str]:
         iters = self._get_iters(run_trace)
         return [f"--iterations={iters}"] + [
-            f"--{k}={v}" for k, v in self._smq_params.items()
+            f"--{k}={v}" for k, v in self._torch_params.items()
         ]
 
 
+class TorchMultiQueue(TorchBenchmark):
+    def __init__(
+        self, suite, runtime: RUNTIMES, variant_name: str, profiler_type, **kwargs
+    ):
+        super().__init__(
+            suite,
+            runtime,
+            "KernelSubmitMultiQueue",
+            variant_name,
+            profiler_type,
+            **kwargs,
+        )
+
+
+class TorchLinearKernelSize(TorchBenchmark):
+    def __init__(
+        self, suite, runtime: RUNTIMES, variant_name: str, profiler_type, **kwargs
+    ):
+        super().__init__(
+            suite,
+            runtime,
+            "KernelSubmitLinearKernelSize",
+            variant_name,
+            profiler_type,
+            **kwargs,
+        )
+
+
 class QueueInOrderMemcpy(ComputeBenchmark):
     def __init__(self, bench, isCopyOnly, source, destination, size, profiler_type):
         self._is_copy_only = isCopyOnly
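
For context, the refactoring above extracts a shared `TorchBenchmark` base so that `TorchMultiQueue` and the new `TorchLinearKernelSize` differ only in the benchmark name they pass up. A minimal sketch of what that base class does with the keyword arguments, reproducing only the `name()` and `_bin_args()` logic visible in this diff; the real classes inherit from `ComputeBenchmark`, which supplies `_bench_name`, runtime, and profiler handling that is omitted here:

```python
# Hedged sketch, not the real benchmark class: it mirrors only the
# name()/flag composition shown in the diff.
class TorchBenchmarkSketch:
    def __init__(self, bench_name: str, **kwargs):
        self._bench_name = bench_name  # e.g. "KernelSubmitLinearKernelSize"
        self._torch_params = kwargs    # forwarded to the binary as --key=value

    def name(self) -> str:
        # Sorted "key value" pairs appended after the benchmark name.
        ret = sorted(f"{k} {v}" for k, v in self._torch_params.items())
        return self._bench_name + " " + ", ".join(ret)

    def bin_args(self, iterations: int = 1000) -> list[str]:
        # Every keyword argument becomes a command-line flag.
        return [f"--iterations={iterations}"] + [
            f"--{k}={v}" for k, v in self._torch_params.items()
        ]


# The "array32" variant added in benchmarks():
b = TorchBenchmarkSketch(
    "KernelSubmitLinearKernelSize", kernelBatchSize=512, kernelSize=32
)
print(b.name())      # KernelSubmitLinearKernelSize kernelBatchSize 512, kernelSize 32
print(b.bin_args())  # ['--iterations=1000', '--kernelBatchSize=512', '--kernelSize=32']
```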