@@ -255,6 +255,7 @@ struct vk_device_struct {
255255 vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16;
256256 vk_pipeline pipeline_argsort_f32;
257257 vk_pipeline pipeline_sum_rows_f32;
258+ vk_pipeline pipeline_argmax_f32;
258259 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
259260 vk_pipeline pipeline_timestep_embedding_f32;
260261 vk_pipeline pipeline_pool2d_f32;
@@ -2203,6 +2204,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
22032204
22042205 ggml_vk_create_pipeline (device, device->pipeline_argsort_f32 , " argsort_f32" , argsort_f32_len, argsort_f32_data, " main" , 2 , sizeof (vk_op_argsort_push_constants), {1024 , 1 , 1 }, {}, 1 );
22052206
2207+ ggml_vk_create_pipeline (device, device->pipeline_argmax_f32 , " argmax_f32" , argmax_f32_len, argmax_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
2208+
22062209 ggml_vk_create_pipeline (device, device->pipeline_sum_rows_f32 , " sum_rows_f32" , sum_rows_f32_len, sum_rows_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
22072210
22082211 ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
@@ -5364,6 +5367,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53645367 return ctx->device ->pipeline_sum_rows_f32 ;
53655368 }
53665369 return nullptr ;
5370+ case GGML_OP_ARGMAX:
5371+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
5372+ return ctx->device ->pipeline_argmax_f32 ;
5373+ }
5374+ return nullptr ;
53675375 case GGML_OP_IM2COL:
53685376 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
53695377 return ctx->device ->pipeline_im2col_f32 ;
@@ -5628,6 +5636,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
56285636 case GGML_OP_RMS_NORM:
56295637 case GGML_OP_SOFT_MAX:
56305638 case GGML_OP_SUM_ROWS:
5639+ case GGML_OP_ARGMAX:
56315640 {
56325641 const uint32_t nr = ggml_nrows (src0);
56335642 if (nr > 262144 ) {
@@ -6240,6 +6249,10 @@ static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx,
62406249 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_SUM_ROWS, { (uint32_t )src0->ne [0 ], 0 , 0 .0f , 0 .0f }, dryrun);
62416250}
62426251
6252+ static void ggml_vk_argmax (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
6253+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_ARGMAX, { (uint32_t )src0->ne [0 ], 0 , 0 .0f , 0 .0f }, dryrun);
6254+ }
6255+
62436256static void ggml_vk_im2col (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
62446257 const int32_t s0 = dst->op_params [0 ];
62456258 const int32_t s1 = dst->op_params [1 ];
@@ -7131,6 +7144,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71317144 case GGML_OP_ARGSORT:
71327145 case GGML_OP_SUM:
71337146 case GGML_OP_SUM_ROWS:
7147+ case GGML_OP_ARGMAX:
71347148 case GGML_OP_IM2COL:
71357149 case GGML_OP_TIMESTEP_EMBEDDING:
71367150 case GGML_OP_POOL_2D:
@@ -7183,6 +7197,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71837197 case GGML_OP_ARGSORT:
71847198 case GGML_OP_SUM:
71857199 case GGML_OP_SUM_ROWS:
7200+ case GGML_OP_ARGMAX:
71867201 case GGML_OP_IM2COL:
71877202 case GGML_OP_TIMESTEP_EMBEDDING:
71887203 case GGML_OP_POOL_2D:
@@ -7310,6 +7325,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73107325 case GGML_OP_SUM_ROWS:
73117326 ggml_vk_sum_rows (ctx, compute_ctx, src0, node, dryrun);
73127327
7328+ break ;
7329+ case GGML_OP_ARGMAX:
7330+ ggml_vk_argmax (ctx, compute_ctx, src0, node, dryrun);
7331+
73137332 break ;
73147333 case GGML_OP_IM2COL:
73157334 ggml_vk_im2col (ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7422,6 +7441,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
74227441 case GGML_OP_ARGSORT:
74237442 case GGML_OP_SUM:
74247443 case GGML_OP_SUM_ROWS:
7444+ case GGML_OP_ARGMAX:
74257445 case GGML_OP_IM2COL:
74267446 case GGML_OP_TIMESTEP_EMBEDDING:
74277447 case GGML_OP_POOL_2D:
@@ -8353,6 +8373,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
83538373 case GGML_OP_ARGSORT:
83548374 case GGML_OP_SUM:
83558375 case GGML_OP_SUM_ROWS:
8376+ case GGML_OP_ARGMAX:
83568377 case GGML_OP_IM2COL:
83578378 case GGML_OP_TIMESTEP_EMBEDDING:
83588379 case GGML_OP_POOL_2D:
@@ -8937,6 +8958,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
89378958 tensor_clone = ggml_sum (ggml_ctx, src0_clone);
89388959 } else if (tensor->op == GGML_OP_SUM_ROWS) {
89398960 tensor_clone = ggml_sum_rows (ggml_ctx, src0_clone);
8961+ } else if (tensor->op == GGML_OP_ARGMAX) {
8962+ tensor_clone = ggml_argmax (ggml_ctx, src0_clone);
89408963 } else if (tensor->op == GGML_OP_IM2COL) {
89418964 const int32_t s0 = tensor->op_params [0 ];
89428965 const int32_t s1 = tensor->op_params [1 ];
0 commit comments