@@ -182,7 +182,7 @@ def benchmarks(self) -> list[Benchmark]:
182182 GraphApiSinKernelGraph (self , runtime , with_graphs , num_kernels )
183183 )
184184
185- # Add ULLS benchmarks
185+ # Add ULLS benchmarks
186186 for runtime in list (RUNTIMES ):
187187 if runtime == RUNTIMES .SYCL :
188188 benches .append (
@@ -355,6 +355,46 @@ def createTorchMultiQueueBench(variant_name: str, **kwargs):
355355 ),
356356 ]
357357
358+ # Add LinearKernelSize benchmarks
359+ for runtime in filter (lambda x : x != RUNTIMES .UR , RUNTIMES ):
360+
361+ def createLinearKernelSizeBench (variant_name : str , ** kwargs ):
362+ return TorchLinearKernelSize (
363+ self ,
364+ runtime ,
365+ variant_name ,
366+ PROFILERS .TIMER ,
367+ ** kwargs ,
368+ )
369+
370+ benches += [
371+ createLinearKernelSizeBench (
372+ "array32" ,
373+ kernelBatchSize = 512 ,
374+ kernelSize = 32 ,
375+ ),
376+ createLinearKernelSizeBench (
377+ "array128" ,
378+ kernelBatchSize = 512 ,
379+ kernelSize = 128 ,
380+ ),
381+ createLinearKernelSizeBench (
382+ "array512" ,
383+ kernelBatchSize = 512 ,
384+ kernelSize = 512 ,
385+ ),
386+ createLinearKernelSizeBench (
387+ "array1024" ,
388+ kernelBatchSize = 512 ,
389+ kernelSize = 1024 ,
390+ ),
391+ createLinearKernelSizeBench (
392+ "array5120" ,
393+ kernelBatchSize = 512 ,
394+ kernelSize = 5120 ,
395+ ),
396+ ]
397+
358398 # Add UR-specific benchmarks
359399 benches += [
360400 # TODO: multithread_benchmark_ur fails with segfault
@@ -810,25 +850,25 @@ def _bin_args(self, run_trace: TracingType = TracingType.NONE) -> list[str]:
810850 return [f"--{ k } ={ v } " for k , v in self ._rr_params .items ()]
811851
812852
813- class TorchMultiQueue (ComputeBenchmark ):
853+ class TorchBenchmark (ComputeBenchmark ):
814854 def __init__ (
815- self , suite , runtime : RUNTIMES , variant_name : str , profiler_type , ** kwargs
855+ self , suite , runtime : RUNTIMES , bench_name : str , variant_name : str , profiler_type , ** kwargs
816856 ):
817857 self ._variant_name = variant_name
818- self ._smq_params = kwargs
858+ self ._torch_params = kwargs
819859 self ._iterations_regular = 1000
820860 self ._iterations_trace = 10
821861 super ().__init__ (
822862 suite ,
823863 f"torch_benchmark_{ runtime .value } " ,
824- "KernelSubmitMultiQueue" ,
864+ bench_name ,
825865 runtime ,
826866 profiler_type ,
827867 )
828868
829869 def name (self ):
830870 ret = []
831- for k , v in self ._smq_params .items ():
871+ for k , v in self ._torch_params .items ():
832872 ret .append (f"{ k } { v } " )
833873 ret .sort ()
834874 return self ._bench_name + " " + ", " .join (ret )
@@ -848,10 +888,38 @@ def _supported_runtimes(self) -> list[RUNTIMES]:
848888 def _bin_args (self , run_trace : TracingType = TracingType .NONE ) -> list [str ]:
849889 iters = self ._get_iters (run_trace )
850890 return [f"--iterations={ iters } " ] + [
851- f"--{ k } ={ v } " for k , v in self ._smq_params .items ()
891+ f"--{ k } ={ v } " for k , v in self ._torch_params .items ()
852892 ]
853893
854894
895+ class TorchMultiQueue (TorchBenchmark ):
896+ def __init__ (
897+ self , suite , runtime : RUNTIMES , variant_name : str , profiler_type , ** kwargs
898+ ):
899+ super ().__init__ (
900+ suite ,
901+ runtime ,
902+ "KernelSubmitMultiQueue" ,
903+ variant_name ,
904+ profiler_type ,
905+ ** kwargs ,
906+ )
907+
908+
909+ class TorchLinearKernelSize (TorchBenchmark ):
910+ def __init__ (
911+ self , suite , runtime : RUNTIMES , variant_name : str , profiler_type , ** kwargs
912+ ):
913+ super ().__init__ (
914+ suite ,
915+ runtime ,
916+ "KernelSubmitLinearKernelSize" ,
917+ variant_name ,
918+ profiler_type ,
919+ ** kwargs ,
920+ )
921+
922+
855923class QueueInOrderMemcpy (ComputeBenchmark ):
856924 def __init__ (self , bench , isCopyOnly , source , destination , size , profiler_type ):
857925 self ._is_copy_only = isCopyOnly
0 commit comments