@@ -257,6 +257,7 @@ struct vk_device_struct {
257257 vk_pipeline pipeline_argsort_f32;
258258 vk_pipeline pipeline_sum_rows_f32;
259259 vk_pipeline pipeline_argmax_f32;
260+ vk_pipeline pipeline_count_equal_i32;
260261 vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
261262 vk_pipeline pipeline_timestep_embedding_f32;
262263 vk_pipeline pipeline_pool2d_f32;
@@ -2211,6 +2212,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
22112212
22122213 ggml_vk_create_pipeline (device, device->pipeline_sum_rows_f32 , " sum_rows_f32" , sum_rows_f32_len, sum_rows_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {1 , 1 , 1 }, { device->subgroup_size }, 1 );
22132214
2215+ ggml_vk_create_pipeline (device, device->pipeline_count_equal_i32 , " count_equal_i32" , count_equal_i32_len, count_equal_i32_data, " main" , 3 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 );
2216+
22142217 ggml_vk_create_pipeline (device, device->pipeline_im2col_f32 , " im2col_f32" , im2col_f32_len, im2col_f32_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
22152218 if (device->float_controls_rte_fp16 ) {
22162219 ggml_vk_create_pipeline (device, device->pipeline_im2col_f32_f16 , " im2col_f32_f16" , im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, " main" , 2 , sizeof (vk_op_im2col_push_constants), {512 , 1 , 1 }, { device->subgroup_size }, 1 , true );
@@ -5380,6 +5383,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
53805383 return ctx->device ->pipeline_argmax_f32 ;
53815384 }
53825385 return nullptr ;
5386+ case GGML_OP_COUNT_EQUAL:
5387+ if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) {
5388+ return ctx->device ->pipeline_count_equal_i32 ;
5389+ }
5390+ return nullptr ;
53835391 case GGML_OP_IM2COL:
53845392 if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
53855393 return ctx->device ->pipeline_im2col_f32 ;
@@ -6278,6 +6286,11 @@ static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, co
62786286 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_ARGMAX, { (uint32_t )src0->ne [0 ], 0 , 0 .0f , 0 .0f }, dryrun);
62796287}
62806288
6289+ static void ggml_vk_count_equal (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
6290+ ggml_backend_tensor_memset (dst, 0 , 0 , ggml_nbytes (dst));
6291+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr , dst, GGML_OP_COUNT_EQUAL, { (uint32_t )ggml_nelements (src0), 0 , 0 .0f , 0 .0f }, dryrun);
6292+ }
6293+
62816294static void ggml_vk_im2col (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
62826295 const int32_t s0 = dst->op_params [0 ];
62836296 const int32_t s1 = dst->op_params [1 ];
@@ -7171,6 +7184,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71717184 case GGML_OP_SUM:
71727185 case GGML_OP_SUM_ROWS:
71737186 case GGML_OP_ARGMAX:
7187+ case GGML_OP_COUNT_EQUAL:
71747188 case GGML_OP_IM2COL:
71757189 case GGML_OP_TIMESTEP_EMBEDDING:
71767190 case GGML_OP_POOL_2D:
@@ -7225,6 +7239,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72257239 case GGML_OP_SUM:
72267240 case GGML_OP_SUM_ROWS:
72277241 case GGML_OP_ARGMAX:
7242+ case GGML_OP_COUNT_EQUAL:
72287243 case GGML_OP_IM2COL:
72297244 case GGML_OP_TIMESTEP_EMBEDDING:
72307245 case GGML_OP_POOL_2D:
@@ -7360,6 +7375,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73607375 case GGML_OP_ARGMAX:
73617376 ggml_vk_argmax (ctx, compute_ctx, src0, node, dryrun);
73627377
7378+ break ;
7379+ case GGML_OP_COUNT_EQUAL:
7380+ ggml_vk_count_equal (ctx, compute_ctx, src0, src1, node, dryrun);
7381+
73637382 break ;
73647383 case GGML_OP_IM2COL:
73657384 ggml_vk_im2col (ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7474,6 +7493,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
74747493 case GGML_OP_SUM:
74757494 case GGML_OP_SUM_ROWS:
74767495 case GGML_OP_ARGMAX:
7496+ case GGML_OP_COUNT_EQUAL:
74777497 case GGML_OP_IM2COL:
74787498 case GGML_OP_TIMESTEP_EMBEDDING:
74797499 case GGML_OP_POOL_2D:
@@ -8407,6 +8427,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
84078427 case GGML_OP_SUM:
84088428 case GGML_OP_SUM_ROWS:
84098429 case GGML_OP_ARGMAX:
8430+ case GGML_OP_COUNT_EQUAL:
84108431 case GGML_OP_IM2COL:
84118432 case GGML_OP_TIMESTEP_EMBEDDING:
84128433 case GGML_OP_POOL_2D:
@@ -8995,6 +9016,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
89959016 tensor_clone = ggml_sum_rows (ggml_ctx, src0_clone);
89969017 } else if (tensor->op == GGML_OP_ARGMAX) {
89979018 tensor_clone = ggml_argmax (ggml_ctx, src0_clone);
9019+ } else if (tensor->op == GGML_OP_COUNT_EQUAL) {
9020+ tensor_clone = ggml_count_equal (ggml_ctx, src0_clone, src1_clone);
89989021 } else if (tensor->op == GGML_OP_IM2COL) {
89999022 const int32_t s0 = tensor->op_params [0 ];
90009023 const int32_t s1 = tensor->op_params [1 ];
@@ -9114,6 +9137,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
91149137 } else if (tensor->type == GGML_TYPE_I32) {
91159138 correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3 ] + i2*comp_nb[2 ] + i1*comp_nb[1 ] + i0*comp_nb[0 ]);
91169139 result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb [3 ] + i2*tensor->nb [2 ] + i1*tensor->nb [1 ] + i0*tensor->nb [0 ]);
9140+ } else if (tensor->type == GGML_TYPE_I64) {
9141+ correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3 ] + i2*comp_nb[2 ] + i1*comp_nb[1 ] + i0*comp_nb[0 ]);
9142+ result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb [3 ] + i2*tensor->nb [2 ] + i1*tensor->nb [1 ] + i0*tensor->nb [0 ]);
91179143 } else {
91189144 std::cerr << " Results check not implemented for type " << ggml_type_name (tensor->type ) << std::endl;
91199145 }
0 commit comments