Skip to content

Commit 9de61b5

Browse files
vulkan: support GGML_OP_SUM
1 parent dca13cb commit 9de61b5

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5358,6 +5358,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53585358
return ctx->device->pipeline_argsort_f32;
53595359
}
53605360
return nullptr;
5361+
case GGML_OP_SUM:
53615362
case GGML_OP_SUM_ROWS:
53625363
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
53635364
return ctx->device->pipeline_sum_rows_f32;
@@ -5637,6 +5638,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
56375638
elements = { nr, 1, 1 };
56385639
}
56395640
} break;
5641+
case GGML_OP_SUM:
5642+
// We use GGML_OP_SUM_ROWS with 1 row.
5643+
elements = { 1, 1, 1 };
5644+
break;
56405645
case GGML_OP_GROUP_NORM:
56415646
{
56425647
const uint32_t num_groups = dst->op_params[0];
@@ -6227,6 +6232,10 @@ static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, c
62276232
}, dryrun);
62286233
}
62296234

6235+
static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6236+
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
6237+
}
6238+
62306239
static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
62316240
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
62326241
}
@@ -7120,6 +7129,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71207129
case GGML_OP_MUL_MAT:
71217130
case GGML_OP_MUL_MAT_ID:
71227131
case GGML_OP_ARGSORT:
7132+
case GGML_OP_SUM:
71237133
case GGML_OP_SUM_ROWS:
71247134
case GGML_OP_IM2COL:
71257135
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7171,6 +7181,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71717181
case GGML_OP_SOFT_MAX:
71727182
case GGML_OP_ROPE:
71737183
case GGML_OP_ARGSORT:
7184+
case GGML_OP_SUM:
71747185
case GGML_OP_SUM_ROWS:
71757186
case GGML_OP_IM2COL:
71767187
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -7291,6 +7302,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72917302
case GGML_OP_ARGSORT:
72927303
ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun);
72937304

7305+
break;
7306+
case GGML_OP_SUM:
7307+
ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun);
7308+
72947309
break;
72957310
case GGML_OP_SUM_ROWS:
72967311
ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun);
@@ -7405,6 +7420,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
74057420
case GGML_OP_TRANSPOSE:
74067421
case GGML_OP_NONE:
74077422
case GGML_OP_ARGSORT:
7423+
case GGML_OP_SUM:
74087424
case GGML_OP_SUM_ROWS:
74097425
case GGML_OP_IM2COL:
74107426
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8335,6 +8351,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
83358351
case GGML_OP_DIAG_MASK_INF:
83368352
case GGML_OP_SOFT_MAX:
83378353
case GGML_OP_ARGSORT:
8354+
case GGML_OP_SUM:
83388355
case GGML_OP_SUM_ROWS:
83398356
case GGML_OP_IM2COL:
83408357
case GGML_OP_TIMESTEP_EMBEDDING:
@@ -8916,6 +8933,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
89168933
tensor_clone = ggml_get_rows(ggml_ctx, src0_clone, src1_clone);
89178934
} else if (tensor->op == GGML_OP_ARGSORT) {
89188935
tensor_clone = ggml_argsort(ggml_ctx, src0_clone, (ggml_sort_order) *(int *)tensor->op_params);
8936+
} else if (tensor->op == GGML_OP_SUM) {
8937+
tensor_clone = ggml_sum(ggml_ctx, src0_clone);
89198938
} else if (tensor->op == GGML_OP_SUM_ROWS) {
89208939
tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
89218940
} else if (tensor->op == GGML_OP_IM2COL) {

0 commit comments

Comments
 (0)