Commit 7a6cb87

vulkan: implement GGML_OP_COUNT_EQUAL
1 parent eacca58 · commit 7a6cb87

4 files changed, +61 -2 lines
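
For context: GGML_OP_COUNT_EQUAL compares two tensors element-wise and returns the number of matching positions as a single-element I64 tensor; the backend test below uses it to count rows where two argmax results agree. A minimal CPU sketch of the semantics for flat I32 data (a hypothetical helper for illustration, not part of this commit):

    // count_equal_ref: reference semantics of GGML_OP_COUNT_EQUAL on flat i32 data.
    #include <cstdint>
    #include <cstddef>

    static int64_t count_equal_ref(const int32_t * a, const int32_t * b, size_t n) {
        int64_t count = 0;
        for (size_t i = 0; i < n; ++i) {
            count += a[i] == b[i] ? 1 : 0;  // one per matching position
        }
        return count;
    }

The Vulkan implementation below computes this sum with a buffer clear plus a single compute dispatch.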

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 26 additions & 0 deletions
@@ -257,6 +257,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_argsort_f32;
     vk_pipeline pipeline_sum_rows_f32;
     vk_pipeline pipeline_argmax_f32;
+    vk_pipeline pipeline_count_equal_i32;
     vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
@@ -2211,6 +2212,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
     if (device->float_controls_rte_fp16) {
         ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
@@ -5380,6 +5383,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_argmax_f32;
         }
         return nullptr;
+    case GGML_OP_COUNT_EQUAL:
+        if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
+            return ctx->device->pipeline_count_equal_i32;
+        }
+        return nullptr;
     case GGML_OP_IM2COL:
         if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_im2col_f32;
@@ -6278,6 +6286,11 @@ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, co
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], 0, 0.0f, 0.0f }, dryrun);
 }

+static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    ggml_backend_tensor_memset(dst, 0, 0, ggml_nbytes(dst));
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+}
+
 static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const int32_t s0 = dst->op_params[0];
     const int32_t s1 = dst->op_params[1];
@@ -7171,6 +7184,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -7225,6 +7239,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -7360,6 +7375,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ARGMAX:
         ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun);

+        break;
+    case GGML_OP_COUNT_EQUAL:
+        ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_IM2COL:
         ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7474,6 +7493,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -8407,6 +8427,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_SUM:
     case GGML_OP_SUM_ROWS:
     case GGML_OP_ARGMAX:
+    case GGML_OP_COUNT_EQUAL:
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
@@ -8995,6 +9016,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_sum_rows(ggml_ctx, src0_clone);
     } else if (tensor->op == GGML_OP_ARGMAX) {
         tensor_clone = ggml_argmax(ggml_ctx, src0_clone);
+    } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
+        tensor_clone = ggml_count_equal(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_IM2COL) {
         const int32_t s0 = tensor->op_params[0];
         const int32_t s1 = tensor->op_params[1];
@@ -9114,6 +9137,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
     } else if (tensor->type == GGML_TYPE_I32) {
         correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
         result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
+    } else if (tensor->type == GGML_TYPE_I64) {
+        correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
+        result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
     } else {
         std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
     }
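
On the host side, ggml_vk_count_equal() zeroes the destination tensor with ggml_backend_tensor_memset() before the dispatch, because the shader accumulates into it with atomicAdd rather than overwriting it, and it passes the total element count as the first push constant (p.KX in the shader). Given the pipeline's workgroup granularity of {512, 1, 1}, matching CHUNK_SIZE in count_equal.comp below, the workgroup count should come out to a ceiling division, roughly as in this sketch (an illustration, not code from the commit):

    // Sketch: one workgroup per 512-element chunk of the input.
    #include <cstdint>

    static uint32_t count_equal_workgroups(uint32_t nelements) {
        const uint32_t CHUNK_SIZE = 512;  // matches the shader's chunk size
        return (nelements + CHUNK_SIZE - 1) / CHUNK_SIZE;
    }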
ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
+#version 450
+
+#extension GL_EXT_control_flow_attributes : enable
+
+#include "types.comp"
+#include "generic_head.comp"
+
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
+layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
+
+const uint CHUNK_SIZE = 512;
+
+void main() {
+    const uint base = gl_WorkGroupID.x * CHUNK_SIZE;
+    const uint col = gl_LocalInvocationID.x;
+
+    uint count = 0;
+    [[unroll]]
+    for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) {
+        const uint idx = base + i + col;
+        if (idx >= p.KX) {
+            break;
+        }
+        count += uint(data_a[idx] == data_b[idx]);
+    }
+
+    atomicAdd(data_d[0], D_TYPE(count));
+}
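
Two things are worth noting about the kernel. Each invocation keeps a private count over a strided slice of its workgroup's 512-element chunk and issues a single atomicAdd at the end, so contention on data_d[0] is one atomic per invocation rather than one per element. And although the destination tensor is GGML_TYPE_I64, D_TYPE is compiled as a 32-bit int (see vulkan-shaders-gen.cpp below); this appears to rely on the host pre-zeroing the 8-byte result, so that on a little-endian layout the atomic add into the low word reads back as a correct int64 as long as the count fits in 32 bits. A sketch of that interplay (an assumption about the intent, not code from the commit):

    // Illustration: a 32-bit add into a pre-zeroed 64-bit slot, little-endian.
    #include <cstdint>
    #include <cstring>
    #include <cassert>

    int main() {
        int64_t result = 0;        // host memsets the I64 result to zero
        uint32_t count = 12345;    // value the shader atomicAdds as an int
        std::memcpy(&result, &count, sizeof(count));  // writes the low word
        assert(result == 12345);
        return 0;
    }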

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
@@ -505,6 +505,7 @@ void process_shaders() {

     string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}}));
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+    string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));

     string_to_spv("im2col_f32", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));

tests/test-backend-ops.cpp

Lines changed: 3 additions & 2 deletions
@@ -1254,7 +1254,7 @@ struct test_count_equal : public test_case {
         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_set_name(b, "b");

-        ggml_tensor * b_argmax = ggml_argmax(ctx, a);
+        ggml_tensor * b_argmax = ggml_argmax(ctx, b);
         ggml_set_name(b_argmax, "b_argmax");

         ggml_tensor * out = ggml_count_equal(ctx, a_argmax, b_argmax);
@@ -3861,7 +3861,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));

-    test_cases.emplace_back(new test_count_equal());
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 500, 1, 1}));
+    test_cases.emplace_back(new test_count_equal(GGML_TYPE_F32, {4, 5000, 1, 1}));

     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
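
The test change fixes a bug (b_argmax was computed from a, making both operands identical and the count trivially equal to the row count) and parameterizes test_count_equal by type and shape. The {4, 5000, 1, 1} case produces 5000 argmax values, spanning multiple 512-element chunks, so the multi-workgroup atomicAdd path is exercised. A minimal usage sketch of the op under test, assuming a valid ggml context ctx with enough memory (headers, allocation, and graph compute omitted):

    // Count how often the per-row argmax of two f32 tensors agrees.
    ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 500);
    ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 500);
    ggml_tensor * out = ggml_count_equal(ctx, ggml_argmax(ctx, a), ggml_argmax(ctx, b));
    // out is a one-element GGML_TYPE_I64 tensor holding the match count.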
