@@ -5358,6 +5358,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53585358 return ctx->device ->pipeline_argsort_f32 ;
53595359 }
53605360 return nullptr ;
5361+ case GGML_OP_SUM:
53615362 case GGML_OP_SUM_ROWS:
53625363 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
53635364 return ctx->device ->pipeline_sum_rows_f32 ;
@@ -5637,6 +5638,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
56375638 elements = { nr, 1 , 1 };
56385639 }
56395640 } break ;
5641+ case GGML_OP_SUM:
5642+ // We use GGML_OP_SUM_ROWS with 1 row.
5643+ elements = { 1 , 1 , 1 };
5644+ break ;
56405645 case GGML_OP_GROUP_NORM:
56415646 {
56425647 const uint32_t num_groups = dst->op_params [0 ];
@@ -6227,6 +6232,10 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
62276232 }, dryrun);
62286233}
62296234
// Records the Vulkan work for GGML_OP_SUM on `src0`, writing into `dst`.
//
// Thin wrapper over ggml_vk_op_f32<vk_op_push_constants>: `src0` is the only
// tensor input (the two following tensor arguments are passed as nullptr) and
// `dryrun` is forwarded unchanged.
//
// The first push-constant value is the total element count of `src0`
// (ggml_nelements), i.e. the whole tensor is handed to the op as one run of
// values. Elsewhere in this change GGML_OP_SUM shares the sum-rows pipeline
// and is dispatched with elements = { 1, 1, 1 } ("GGML_OP_SUM_ROWS with 1 row").
// The remaining push-constant values are 0 / 0.0f; their field meanings are
// not visible here -- TODO confirm against vk_op_push_constants.
//
// NOTE(review): the element count is narrowed with a (uint32_t) cast. If
// ggml_nelements returns a wider integer type, a tensor with more than
// UINT32_MAX elements would be silently truncated -- confirm the supported
// size range.
// NOTE(review): treating the tensor as a single row presumably requires src0
// to be laid out contiguously -- verify that callers / supports_op enforce it.
6235+ static void ggml_vk_sum (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
6236+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_SUM, { (uint32_t )ggml_nelements (src0), 0 , 0 .0f , 0 .0f }, dryrun);
6237+ }
6238+
// Records the Vulkan work for GGML_OP_SUM_ROWS on `src0`, writing into `dst`.
//
// Thin wrapper over ggml_vk_op_f32<vk_op_push_constants>: `src0` is the only
// tensor input (the two following tensor arguments are passed as nullptr) and
// `dryrun` is forwarded unchanged.
//
// The first push-constant value is src0->ne[0] cast to uint32_t -- presumably
// the length of one row to be reduced; TODO confirm against the shader. The
// remaining push-constant values are 0 / 0.0f; their field meanings are not
// visible here.
62306239static void ggml_vk_sum_rows (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
62316240 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_SUM_ROWS, { (uint32_t )src0->ne [0 ], 0 , 0 .0f , 0 .0f }, dryrun);
62326241}
@@ -7120,6 +7129,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71207129 case GGML_OP_MUL_MAT:
71217130 case GGML_OP_MUL_MAT_ID:
71227131 case GGML_OP_ARGSORT:
7132+ case GGML_OP_SUM:
71237133 case GGML_OP_SUM_ROWS:
71247134 case GGML_OP_IM2COL:
71257135 case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7171,6 +7181,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71717181 case GGML_OP_SOFT_MAX:
71727182 case GGML_OP_ROPE:
71737183 case GGML_OP_ARGSORT:
7184+ case GGML_OP_SUM:
71747185 case GGML_OP_SUM_ROWS:
71757186 case GGML_OP_IM2COL:
71767187 case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7291,6 +7302,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72917302 case GGML_OP_ARGSORT:
72927303 ggml_vk_argsort (ctx, compute_ctx, src0, node, dryrun);
72937304
7305+ break ;
7306+ case GGML_OP_SUM:
7307+ ggml_vk_sum (ctx, compute_ctx, src0, node, dryrun);
7308+
72947309 break ;
72957310 case GGML_OP_SUM_ROWS:
72967311 ggml_vk_sum_rows (ctx, compute_ctx, src0, node, dryrun);
@@ -7405,6 +7420,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
74057420 case GGML_OP_TRANSPOSE:
74067421 case GGML_OP_NONE:
74077422 case GGML_OP_ARGSORT:
7423+ case GGML_OP_SUM:
74087424 case GGML_OP_SUM_ROWS:
74097425 case GGML_OP_IM2COL:
74107426 case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8335,6 +8351,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
83358351 case GGML_OP_DIAG_MASK_INF:
83368352 case GGML_OP_SOFT_MAX:
83378353 case GGML_OP_ARGSORT:
8354+ case GGML_OP_SUM:
83388355 case GGML_OP_SUM_ROWS:
83398356 case GGML_OP_IM2COL:
83408357 case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8916,6 +8933,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
89168933 tensor_clone = ggml_get_rows (ggml_ctx, src0_clone, src1_clone);
89178934 } else if (tensor->op == GGML_OP_ARGSORT) {
89188935 tensor_clone = ggml_argsort (ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params );
8936+ } else if (tensor->op == GGML_OP_SUM) {
8937+ tensor_clone = ggml_sum (ggml_ctx, src0_clone);
89198938 } else if (tensor->op == GGML_OP_SUM_ROWS) {
89208939 tensor_clone = ggml_sum_rows (ggml_ctx, src0_clone);
89218940 } else if (tensor->op == GGML_OP_IM2COL) {
0 commit comments