Commit eacca58

vulkan: implement GGML_OP_SUB

1 parent f354a78 · commit eacca58

Summary: adds element-wise subtraction (GGML_OP_SUB) for f32 tensors to the Vulkan backend: a new sub.comp compute shader, sub_f32 pipelines (with a norepeat variant for the no-broadcast case), routing through the backend's op dispatch and support tables, and broadcast coverage in test-backend-ops.

File tree: 4 files changed, +68 −1 lines

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 35 additions & 0 deletions
@@ -222,6 +222,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
     vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+    vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
     vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
     vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -2148,6 +2149,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -5192,6 +5195,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
        }
        return nullptr;
+    case GGML_OP_SUB:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
+        }
+        return nullptr;
    case GGML_OP_MUL:
        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
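As with the other binary ops, two pipeline variants exist per op, chosen by the {0}/{1} specialization constant at pipeline creation. A sketch of what the norepeat variant buys; the assumption is that the broadcast handling lives in generic_binary_head.comp's index helpers, which are not shown in this diff:

    // Illustrative only. ggml broadcasts src1 across src0, so the general
    // variant wraps each src1 coordinate; the norepeat variant is selected
    // only when ggml_are_same_shape(src0, src1), where the wrap is a no-op.
    uint32_t src1_coord(uint32_t i0, uint32_t ne1, bool norepeat) {
        return norepeat ? i0 : i0 % ne1;  // norepeat skips the modulo
    }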
@@ -5412,6 +5420,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
    case GGML_OP_CPY:
    case GGML_OP_GET_ROWS:
    case GGML_OP_ADD:
+    case GGML_OP_SUB:
    case GGML_OP_MUL:
    case GGML_OP_DIV:
    case GGML_OP_CONCAT:
@@ -5697,6 +5706,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
            elements = { N * OC * OH * OW, 1, 1};
        } break;
    case GGML_OP_ADD:
+    case GGML_OP_SUB:
    case GGML_OP_DIV:
    case GGML_OP_MUL:
    case GGML_OP_SCALE:
@@ -5828,6 +5838,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
    }, dryrun);
}

+static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
    const uint32_t src0_type_size = ggml_type_size(src0->type);
    const uint32_t src1_type_size = ggml_type_size(src1->type);
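The trailing initializer values map onto the shared binary-op push-constant block. The struct itself is not part of this diff; the following is a sketch reconstructed from the call site above, so the field names are assumptions:

    // Sketch of vk_op_binary_push_constants, inferred from the initializer
    // in ggml_vk_sub (field names are assumptions, order matches the call):
    struct vk_op_binary_push_constants {
        uint32_t ne;                                              // total element count
        uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03;  // src0 shape / strides (in elements)
        uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13;  // src1 shape / strides
        uint32_t ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23;  // dst  shape / strides
        uint32_t misalign_offsets;                                // 0 here
        float    param1, param2;                                  // unused by sub (0.0f)
        int32_t  param3;                                          // unused by sub (0)
    };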
@@ -7120,6 +7145,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_GET_ROWS:
    case GGML_OP_ADD:
    case GGML_OP_ACC:
+    case GGML_OP_SUB:
    case GGML_OP_MUL:
    case GGML_OP_DIV:
    case GGML_OP_CONCAT:
@@ -7174,6 +7200,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_ACC:
    case GGML_OP_GET_ROWS:
    case GGML_OP_ADD:
+    case GGML_OP_SUB:
    case GGML_OP_MUL:
    case GGML_OP_DIV:
    case GGML_OP_CONCAT:
@@ -7230,6 +7257,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_ADD:
        ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);

+        break;
+    case GGML_OP_SUB:
+        ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
+
        break;
    case GGML_OP_MUL:
        ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7414,6 +7445,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_ADD:
    case GGML_OP_ACC:
    case GGML_OP_GET_ROWS:
+    case GGML_OP_SUB:
    case GGML_OP_MUL:
    case GGML_OP_DIV:
    case GGML_OP_CONCAT:
@@ -8358,6 +8390,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
            return ggml_is_contiguous(op->src[0]);
        case GGML_OP_ADD:
        case GGML_OP_ACC:
+        case GGML_OP_SUB:
        case GGML_OP_MUL:
        case GGML_OP_DIV:
        case GGML_OP_CONCAT:
@@ -8854,6 +8887,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
        tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
    } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
        tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
+    } else if (tensor->op == GGML_OP_SUB) {
+        tensor_clone = ggml_sub(ggml_ctx, src0_clone, src1_clone);
    } else if (tensor->op == GGML_OP_MUL) {
        tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
    } else if (tensor->op == GGML_OP_DIV) {

ggml/src/ggml-vulkan/vulkan-shaders/sub.comp (new file; name inferred from the string_to_spv registration below)

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+#version 450
+
+#extension GL_EXT_shader_16bit_storage : require
+
+#include "types.comp"
+#include "generic_binary_head.comp"
+
+const uint num_threads = 256;
+
+layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
+
+void main() {
+    uint idx = get_idx();
+
+    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
+    const uint num_iter = 2;
+
+    [[unroll]] for (uint i = 0; i < num_iter; ++i) {
+        if (idx >= p.ne) {
+            continue;
+        }
+        uint i00, i01, i02, i03;
+        get_indices(idx, i00, i01, i02, i03);
+
+        data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
+
+        idx += num_threads;
+    }
+}
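Why 256 threads but {512, 1, 1} wg_denoms: each invocation handles two elements. A sketch of the resulting workgroup count, assuming a CEIL_DIV-style dispatcher (the dispatch code itself is not part of this diff):

    // Each workgroup runs num_threads = 256 invocations for num_iter = 2
    // elements apiece, so one workgroup covers 256 * 2 = 512 elements,
    // matching the wg_denoms {512, 1, 1} used at pipeline creation.
    uint32_t workgroups_for(uint32_t ne) {
        const uint32_t wg_denom = 512;          // elements per workgroup
        return (ne + wg_denom - 1) / wg_denom;  // round up; the idx >= p.ne
                                                // guard handles the tail
    }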

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
@@ -443,6 +443,8 @@ void process_shaders() {
     string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});

+    string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+
     string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

     string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});

tests/test-backend-ops.cpp

Lines changed: 2 additions & 1 deletion
@@ -1511,6 +1511,7 @@ struct test_cont : public test_case {
 };

 // GGML_OP_ADD
+// GGML_OP_SUB
 // GGML_OP_MUL
 // GGML_OP_DIV
 struct test_bin_bcast : public test_case {
@@ -3938,7 +3939,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
    test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));

    auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
-        for (auto op : {ggml_add, ggml_mul, ggml_div}) {
+        for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
            test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
        }
    };
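End to end, the new op is reachable through the regular ggml graph API. A minimal sketch, assuming the standard ggml backend interface (buffer sizes and device index here are arbitrary):

    #include "ggml.h"
    #include "ggml-alloc.h"
    #include "ggml-backend.h"
    #include "ggml-vulkan.h"

    int main() {
        // graph-only context; tensor data lives in the backend buffer below
        struct ggml_init_params ip = {
            /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
            /*.mem_buffer =*/ nullptr,
            /*.no_alloc   =*/ true,
        };
        struct ggml_context * ctx = ggml_init(ip);

        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
        struct ggml_tensor * c = ggml_sub(ctx, a, b);  // dispatches to sub_f32 on Vulkan

        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);

        ggml_backend_t backend = ggml_backend_vk_init(0);  // device 0
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

        // ... fill a and b via ggml_backend_tensor_set(), then:
        ggml_backend_graph_compute(backend, gf);

        ggml_backend_buffer_free(buf);
        ggml_backend_free(backend);
        ggml_free(ctx);
        return 0;
    }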
