Skip to content

Commit b3064f8

Browse files
[Bench] Add TorchLinearKernelSize benchmark
1 parent 38f555f commit b3064f8

File tree

2 files changed

+91
-7
lines changed

2 files changed

+91
-7
lines changed

devops/scripts/benchmarks/benches/compute.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def benchmarks(self) -> list[Benchmark]:
182182
GraphApiSinKernelGraph(self, runtime, with_graphs, num_kernels)
183183
)
184184

185-
# Add ULLS benchmarks
185+
# Add ULLS benchmarks
186186
for runtime in list(RUNTIMES):
187187
if runtime == RUNTIMES.SYCL:
188188
benches.append(
@@ -355,6 +355,46 @@ def createTorchMultiQueueBench(variant_name: str, **kwargs):
355355
),
356356
]
357357

358+
# Add LinearKernelSize benchmarks
359+
for runtime in filter(lambda x: x != RUNTIMES.UR, RUNTIMES):
360+
361+
def createLinearKernelSizeBench(variant_name: str, **kwargs):
362+
return TorchLinearKernelSize(
363+
self,
364+
runtime,
365+
variant_name,
366+
PROFILERS.TIMER,
367+
**kwargs,
368+
)
369+
370+
benches += [
371+
createLinearKernelSizeBench(
372+
"array32",
373+
kernelBatchSize=512,
374+
kernelSize=32,
375+
),
376+
createLinearKernelSizeBench(
377+
"array128",
378+
kernelBatchSize=512,
379+
kernelSize=128,
380+
),
381+
createLinearKernelSizeBench(
382+
"array512",
383+
kernelBatchSize=512,
384+
kernelSize=512,
385+
),
386+
createLinearKernelSizeBench(
387+
"array1024",
388+
kernelBatchSize=512,
389+
kernelSize=1024,
390+
),
391+
createLinearKernelSizeBench(
392+
"array5120",
393+
kernelBatchSize=512,
394+
kernelSize=5120,
395+
),
396+
]
397+
358398
# Add UR-specific benchmarks
359399
benches += [
360400
# TODO: multithread_benchmark_ur fails with segfault
@@ -810,25 +850,25 @@ def _bin_args(self, run_trace: TracingType = TracingType.NONE) -> list[str]:
810850
return [f"--{k}={v}" for k, v in self._rr_params.items()]
811851

812852

813-
class TorchMultiQueue(ComputeBenchmark):
853+
class TorchBenchmark(ComputeBenchmark):
814854
def __init__(
815-
self, suite, runtime: RUNTIMES, variant_name: str, profiler_type, **kwargs
855+
self, suite, runtime: RUNTIMES, bench_name: str, variant_name: str, profiler_type, **kwargs
816856
):
817857
self._variant_name = variant_name
818-
self._smq_params = kwargs
858+
self._torch_params = kwargs
819859
self._iterations_regular = 1000
820860
self._iterations_trace = 10
821861
super().__init__(
822862
suite,
823863
f"torch_benchmark_{runtime.value}",
824-
"KernelSubmitMultiQueue",
864+
bench_name,
825865
runtime,
826866
profiler_type,
827867
)
828868

829869
def name(self):
830870
ret = []
831-
for k, v in self._smq_params.items():
871+
for k, v in self._torch_params.items():
832872
ret.append(f"{k} {v}")
833873
ret.sort()
834874
return self._bench_name + " " + ", ".join(ret)
@@ -848,10 +888,38 @@ def _supported_runtimes(self) -> list[RUNTIMES]:
848888
def _bin_args(self, run_trace: TracingType = TracingType.NONE) -> list[str]:
849889
iters = self._get_iters(run_trace)
850890
return [f"--iterations={iters}"] + [
851-
f"--{k}={v}" for k, v in self._smq_params.items()
891+
f"--{k}={v}" for k, v in self._torch_params.items()
852892
]
853893

854894

895+
class TorchMultiQueue(TorchBenchmark):
896+
def __init__(
897+
self, suite, runtime: RUNTIMES, variant_name: str, profiler_type, **kwargs
898+
):
899+
super().__init__(
900+
suite,
901+
runtime,
902+
"KernelSubmitMultiQueue",
903+
variant_name,
904+
profiler_type,
905+
**kwargs,
906+
)
907+
908+
909+
class TorchLinearKernelSize(TorchBenchmark):
910+
def __init__(
911+
self, suite, runtime: RUNTIMES, variant_name: str, profiler_type, **kwargs
912+
):
913+
super().__init__(
914+
suite,
915+
runtime,
916+
"KernelSubmitLinearKernelSize",
917+
variant_name,
918+
profiler_type,
919+
**kwargs,
920+
)
921+
922+
855923
class QueueInOrderMemcpy(ComputeBenchmark):
856924
def __init__(self, bench, isCopyOnly, source, destination, size, profiler_type):
857925
self._is_copy_only = isCopyOnly

devops/scripts/benchmarks/tests/test_integration.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,20 +194,36 @@ def test_torch_l0(self):
194194
"KernelSubmitMultiQueue large",
195195
{"pytorch", "L0"},
196196
)
197+
self._checkCase(
198+
"torch_benchmark_l0 kernelBatchSize 512, kernelSize 32",
199+
"KernelSubmitLinearKernelSize array32",
200+
{"pytorch", "L0"},
201+
)
197202

198203
def test_torch_sycl(self):
199204
self._checkCase(
200205
"torch_benchmark_sycl kernelsPerQueue 10, workgroupCount 512, workgroupSize 256",
201206
"KernelSubmitMultiQueue medium",
202207
{"pytorch", "SYCL"},
203208
)
209+
self._checkCase(
210+
"torch_benchmark_sycl kernelBatchSize 512, kernelSize 5120",
211+
"KernelSubmitLinearKernelSize array5120",
212+
{"pytorch", "SYCL"},
213+
)
204214

205215
def test_torch_syclpreview(self):
206216
self._checkCase(
207217
"torch_benchmark_syclpreview kernelsPerQueue 4, workgroupCount 256, workgroupSize 124",
208218
"KernelSubmitMultiQueue small",
209219
{"pytorch", "SYCL"},
210220
)
221+
self._checkCase(
222+
"torch_benchmark_syclpreview kernelBatchSize 512, kernelSize 512",
223+
"KernelSubmitLinearKernelSize array512",
224+
{"pytorch", "SYCL"},
225+
)
226+
211227

212228
if __name__ == "__main__":
213229
unittest.main()

0 commit comments

Comments
 (0)