@@ -233,7 +233,7 @@ struct vk_device_struct {
233233 vk_pipeline pipeline_cos_f32;
234234 vk_pipeline pipeline_clamp_f32;
235235 vk_pipeline pipeline_pad_f32;
236- vk_pipeline pipeline_repeat_f32;
236+ vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32 ;
237237 vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
238238 vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
239239 vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
@@ -2175,6 +2175,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
21752175 ggml_vk_create_pipeline (device, device->pipeline_pad_f32 , " pad_f32" , pad_f32_len, pad_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
21762176
21772177 ggml_vk_create_pipeline (device, device->pipeline_repeat_f32 , " repeat_f32" , repeat_f32_len, repeat_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
2178+ ggml_vk_create_pipeline (device, device->pipeline_repeat_back_f32 , " repeat_back_f32" , repeat_back_f32_len, repeat_back_f32_data, " main" , 2 , sizeof (vk_op_unary_push_constants), {512 , 1 , 1 }, {}, 1 );
21782179
21792180 ggml_vk_create_pipeline (device, device->pipeline_gelu_f32 , " gelu_f32" , gelu_f32_len, gelu_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
21802181 ggml_vk_create_pipeline (device, device->pipeline_gelu_quick_f32 , " gelu_quick_f32" , gelu_quick_f32_len, gelu_quick_f32_data, " main" , 2 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
@@ -5267,6 +5268,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
52675268 return ctx->device ->pipeline_repeat_f32 ;
52685269 }
52695270 return nullptr ;
5271+ case GGML_OP_REPEAT_BACK:
5272+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5273+ return ctx->device ->pipeline_repeat_back_f32 ;
5274+ }
5275+ return nullptr ;
52705276 case GGML_OP_CPY:
52715277 case GGML_OP_CONT:
52725278 case GGML_OP_DUP:
@@ -5447,6 +5453,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
54475453 case GGML_OP_CLAMP:
54485454 case GGML_OP_PAD:
54495455 case GGML_OP_REPEAT:
5456+ case GGML_OP_REPEAT_BACK:
54505457 case GGML_OP_ROPE:
54515458 return true ;
54525459 default :
@@ -5732,6 +5739,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
57325739 case GGML_OP_CLAMP:
57335740 case GGML_OP_PAD:
57345741 case GGML_OP_REPEAT:
5742+ case GGML_OP_REPEAT_BACK:
57355743 case GGML_OP_CPY:
57365744 case GGML_OP_CONCAT:
57375745 case GGML_OP_UPSCALE:
@@ -6265,6 +6273,20 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
62656273 }, dryrun);
62666274}
62676275
6276+ static void ggml_vk_repeat_back (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
6277+ const uint32_t src0_type_size = ggml_type_size (src0->type );
6278+ const uint32_t dst_type_size = ggml_type_size (dst->type );
6279+
6280+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr , nullptr , dst, GGML_OP_REPEAT_BACK, {
6281+ (uint32_t )ggml_nelements (dst),
6282+ (uint32_t )src0->ne [0 ], (uint32_t )src0->ne [1 ], (uint32_t )src0->ne [2 ], (uint32_t )src0->ne [3 ], (uint32_t )src0->nb [0 ] / src0_type_size, (uint32_t )src0->nb [1 ] / src0_type_size, (uint32_t )src0->nb [2 ] / src0_type_size, (uint32_t )src0->nb [3 ] / src0_type_size,
6283+ (uint32_t ) dst->ne [0 ], (uint32_t ) dst->ne [1 ], (uint32_t ) dst->ne [2 ], (uint32_t ) dst->ne [3 ], (uint32_t ) dst->nb [0 ] / dst_type_size, (uint32_t ) dst->nb [1 ] / dst_type_size, (uint32_t ) dst->nb [2 ] / dst_type_size, (uint32_t ) dst->nb [3 ] / dst_type_size,
6284+ 0 ,
6285+ 0 .0f , 0 .0f ,
6286+ 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ,
6287+ }, dryrun);
6288+ }
6289+
62686290static void ggml_vk_cpy (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false ) {
62696291 const uint32_t src0_type_size = ggml_type_size (src0->type );
62706292 const uint32_t dst_type_size = ggml_type_size (dst->type );
@@ -7268,6 +7290,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
72687290 }
72697291 break ;
72707292 case GGML_OP_REPEAT:
7293+ case GGML_OP_REPEAT_BACK:
72717294 case GGML_OP_GET_ROWS:
72727295 case GGML_OP_ADD:
72737296 case GGML_OP_ACC:
@@ -7325,6 +7348,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73257348 } else {
73267349 switch (node->op ) {
73277350 case GGML_OP_REPEAT:
7351+ case GGML_OP_REPEAT_BACK:
73287352 case GGML_OP_ACC:
73297353 case GGML_OP_GET_ROWS:
73307354 case GGML_OP_ADD:
@@ -7374,6 +7398,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
73747398 case GGML_OP_REPEAT:
73757399 ggml_vk_repeat (ctx, compute_ctx, src0, node, dryrun);
73767400
7401+ break ;
7402+ case GGML_OP_REPEAT_BACK:
7403+ ggml_vk_repeat_back (ctx, compute_ctx, src0, node, dryrun);
7404+
73777405 break ;
73787406 case GGML_OP_ACC:
73797407 ggml_vk_acc (ctx, compute_ctx, src0, src1, node, dryrun);
@@ -7619,6 +7647,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
76197647 case GGML_OP_RWKV_WKV6:
76207648 case GGML_OP_LEAKY_RELU:
76217649 case GGML_OP_REPEAT:
7650+ case GGML_OP_REPEAT_BACK:
76227651 case GGML_OP_OPT_STEP_ADAMW:
76237652 buf = tensor->buffer ;
76247653
@@ -8517,6 +8546,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
85178546 } break ;
85188547 case GGML_OP_REPEAT:
85198548 return ggml_type_size (op->type ) == sizeof (float ) && ggml_type_size (op->src [0 ]->type ) == sizeof (float );
8549+ case GGML_OP_REPEAT_BACK:
8550+ return op->type == GGML_TYPE_F32 && op->src [0 ]->type == GGML_TYPE_F32;
85208551 case GGML_OP_ROPE:
85218552 case GGML_OP_NONE:
85228553 case GGML_OP_RESHAPE:
@@ -8922,6 +8953,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
89228953 tensor_clone = ggml_pad (ggml_ctx, src_clone[0 ], tensor->ne [0 ] - src_clone[0 ]->ne [0 ], tensor->ne [1 ] - src_clone[0 ]->ne [1 ], tensor->ne [2 ] - src_clone[0 ]->ne [2 ], tensor->ne [3 ] - src_clone[0 ]->ne [3 ]);
89238954 } else if (tensor->op == GGML_OP_REPEAT) {
89248955 tensor_clone = ggml_repeat (ggml_ctx, src_clone[0 ], tensor);
8956+ } else if (tensor->op == GGML_OP_REPEAT_BACK) {
8957+ tensor_clone = ggml_repeat_back (ggml_ctx, src_clone[0 ], tensor);
89258958 } else if (tensor->op == GGML_OP_ADD) {
89268959 tensor_clone = ggml_add (ggml_ctx, src_clone[0 ], src_clone[1 ]);
89278960 } else if (tensor->op == GGML_OP_ACC) {
0 commit comments