@@ -222,6 +222,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_acc_f32;
     vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
     vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+    vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
     vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
     vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
@@ -2148,6 +2149,8 @@ static void ggml_vk_load_shaders(vk_device& device) {

     ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
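The trailing `{0}` / `{1}` argument is a specialization constant shared by the binary-op pipelines: it selects between the general broadcasting path and the `norepeat` fast path used when both operands have the same shape. A host-side sketch of the distinction (illustrative only, not the actual GLSL shader):

```cpp
// Illustrative sketch, not the shader itself: the "repeat" variant (spec
// constant 0) wraps each src1 coordinate modulo src1's extent so a smaller
// tensor can broadcast over src0; the "norepeat" variant (spec constant 1)
// assumes equal shapes and skips the modulo arithmetic entirely.
static uint32_t src1_offset(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3,
                            const uint32_t ne1[4], const uint32_t nb1[4],
                            bool norepeat) {
    if (norepeat) {
        return i0 * nb1[0] + i1 * nb1[1] + i2 * nb1[2] + i3 * nb1[3];
    }
    return (i0 % ne1[0]) * nb1[0] + (i1 % ne1[1]) * nb1[1] +
           (i2 % ne1[2]) * nb1[2] + (i3 % ne1[3]) * nb1[3];
}
```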
@@ -5192,6 +5195,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
         }
         return nullptr;
+    case GGML_OP_SUB:
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
+        }
+        return nullptr;
     case GGML_OP_MUL:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
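The `norepeat` choice above hinges on `ggml_are_same_shape`, which compares only the four `ne` extents; strides play no part. Roughly (a sketch of the check in `ggml.c`):

```cpp
// Sketch of the shape check behind the norepeat fast path; the nb strides
// are deliberately not compared, so two differently-laid-out tensors of
// equal extents still take the norepeat pipeline.
static bool same_shape(const struct ggml_tensor * a, const struct ggml_tensor * b) {
    return a->ne[0] == b->ne[0] && a->ne[1] == b->ne[1] &&
           a->ne[2] == b->ne[2] && a->ne[3] == b->ne[3];
}
```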
@@ -5412,6 +5420,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_CPY:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
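Adding `GGML_OP_SUB` here declares that the shader handles non-contiguous operands. That works because the push constants carry full per-dimension strides, so addressing reduces to plain stride arithmetic (illustrative):

```cpp
#include "ggml.h"

// Illustrative: element (i0,i1,i2,i3) of a possibly non-contiguous tensor
// is located through its per-dimension byte strides nb[0..3] rather than a
// dense row-major formula.
static size_t byte_offset(const struct ggml_tensor * t,
                          int64_t i0, int64_t i1, int64_t i2, int64_t i3) {
    return i0 * t->nb[0] + i1 * t->nb[1] + i2 * t->nb[2] + i3 * t->nb[3];
}
```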
@@ -5697,6 +5706,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { N * OC * OH * OW, 1, 1 };
         } break;
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_DIV:
     case GGML_OP_MUL:
     case GGML_OP_SCALE:
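By joining this fall-through, SUB inherits the same dispatch sizing as the other elementwise ops: one invocation per destination element. Assuming the `{512, 1, 1}` local size set at pipeline creation, the total workgroup count comes out to roughly the following (a sketch; the in-tree code additionally spreads the count across the x/y/z dispatch dimensions to respect per-dimension limits):

```cpp
// Sketch: total workgroups for an elementwise op at 512 invocations per
// workgroup. Illustrative arithmetic only, not the exact in-tree code.
static uint32_t workgroups_for(const struct ggml_tensor * dst) {
    const uint32_t ne = (uint32_t) ggml_nelements(dst);
    return (ne + 511) / 512; // e.g. 4096 elements -> 8 workgroups
}
```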
@@ -5828,6 +5838,21 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }

+static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t src1_type_size = ggml_type_size(src1->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        0.0f, 0.0f, 0,
+    }, dryrun);
+}
+
 static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t src1_type_size = ggml_type_size(src1->type);
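The braced initializer in `ggml_vk_sub` fills `vk_op_binary_push_constants` positionally. Reconstructed from that order (the real definition lives earlier in this file; field names here are assumptions, not copied from it), it amounts to:

```cpp
// Assumed layout, inferred from the initializer order in ggml_vk_sub above.
struct vk_op_binary_push_constants_sketch {
    uint32_t ne;                                             // total elements in src0
    uint32_t ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03; // src0 extents + strides
    uint32_t ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13; // src1 extents + strides
    uint32_t ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23; // dst extents + strides
    uint32_t offset;                                         // 0 for SUB
    float    param1, param2;                                 // unused by SUB (0.0f)
    int32_t  param3;                                         // unused by SUB (0)
};
```

Note that each `nb` value is divided by the type size before being packed, so the shader sees strides in elements rather than bytes.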
@@ -7120,6 +7145,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
     case GGML_OP_ACC:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7174,6 +7200,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
     case GGML_OP_ADD:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -7230,6 +7257,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_ADD:
         ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun);

+        break;
+    case GGML_OP_SUB:
+        ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_MUL:
         ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun);
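With the dispatch case in place, the new path can be driven end to end through the public API. A minimal sketch, assuming the usual ggml-backend setup (calls per `ggml.h`, `ggml-alloc.h`, `ggml-backend.h`, and `ggml-vulkan.h`; tensor data upload and error handling omitted):

```cpp
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-vulkan.h"

int main(void) {
    ggml_backend_t backend = ggml_backend_vk_init(0); // Vulkan device 0

    struct ggml_init_params ip = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data lives in a backend buffer
    };
    struct ggml_context * ctx = ggml_init(ip);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
    struct ggml_tensor * d = ggml_sub(ctx, a, b); // lowers to ggml_vk_sub above

    // place a, b, and d in Vulkan device memory
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, d);
    ggml_backend_graph_compute(backend, gf); // runs the sub_f32 pipeline

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}
```

In-tree, the same path is exercised by `test-backend-ops` once `GGML_OP_SUB` is advertised in `supports_op` below.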
@@ -7414,6 +7445,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_ADD:
     case GGML_OP_ACC:
     case GGML_OP_GET_ROWS:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -8358,6 +8390,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         return ggml_is_contiguous(op->src[0]);
     case GGML_OP_ADD:
     case GGML_OP_ACC:
+    case GGML_OP_SUB:
     case GGML_OP_MUL:
     case GGML_OP_DIV:
     case GGML_OP_CONCAT:
@@ -8854,6 +8887,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
         tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
+    } else if (tensor->op == GGML_OP_SUB) {
+        tensor_clone = ggml_sub(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_MUL) {
         tensor_clone = ggml_mul(ggml_ctx, src0_clone, src1_clone);
     } else if (tensor->op == GGML_OP_DIV) {