@@ -262,6 +262,7 @@ struct vk_device_struct {
262262 vk_pipeline pipeline_timestep_embedding_f32;
263263 vk_pipeline pipeline_pool2d_f32;
264264 vk_pipeline pipeline_rwkv_wkv6_f32;
265+ vk_pipeline pipeline_opt_step_adamw_f32;
265266
266267 // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
267268 vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2 ][2 ][2 ];
@@ -2227,6 +2228,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
22272228
22282229 ggml_vk_create_pipeline (device, device->pipeline_rwkv_wkv6_f32 , " rwkv_wkv6_f32" , rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, " main" , 7 , sizeof (vk_op_rwkv_wkv6_push_constants), {1 , 1 , 1 }, {device->subgroup_size }, 1 );
22292230
2231+ ggml_vk_create_pipeline (device, device->pipeline_opt_step_adamw_f32 , " opt_step_adamw_f32" , opt_step_adamw_f32_len, opt_step_adamw_f32_data, " main" , 5 , sizeof (vk_op_push_constants), {512 , 1 , 1 }, {}, 1 );
2232+
22302233 for (auto &c : compiles) {
22312234 c.wait ();
22322235 }
@@ -5411,6 +5414,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
54115414 return ctx->device ->pipeline_rwkv_wkv6_f32 ;
54125415 }
54135416 return nullptr ;
5417+ case GGML_OP_OPT_STEP_ADAMW:
5418+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5419+ return ctx->device ->pipeline_opt_step_adamw_f32 ;
5420+ }
5421+ return nullptr ;
54145422 case GGML_OP_LEAKY_RELU:
54155423 if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
54165424 return ctx->device ->pipeline_leaky_relu_f32 ;
@@ -6019,6 +6027,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
60196027 );
60206028}
60216029
6030+ static void ggml_vk_op_f32_opt_step_adamw (ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false ) {
6031+ const ggml_tensor * x = dst->src [0 ];
6032+ const ggml_tensor * g = dst->src [1 ];
6033+ const ggml_tensor * gm = dst->src [2 ];
6034+ const ggml_tensor * gv = dst->src [3 ];
6035+ const ggml_tensor * p = dst->src [4 ];
6036+
6037+ GGML_ASSERT (x->type == GGML_TYPE_F32);
6038+ GGML_ASSERT (g->type == GGML_TYPE_F32);
6039+ GGML_ASSERT (gm->type == GGML_TYPE_F32);
6040+ GGML_ASSERT (gv->type == GGML_TYPE_F32);
6041+ GGML_ASSERT (p->type == GGML_TYPE_F32);
6042+ GGML_ASSERT (dst->buffer != nullptr );
6043+ GGML_ASSERT (ggml_is_contiguous (x));
6044+ GGML_ASSERT (ggml_is_contiguous (g));
6045+ GGML_ASSERT (ggml_is_contiguous (gm));
6046+ GGML_ASSERT (ggml_is_contiguous (gv));
6047+ GGML_ASSERT (ggml_is_contiguous (p));
6048+ GGML_ASSERT (ggml_are_same_shape (x, g));
6049+ GGML_ASSERT (ggml_are_same_shape (x, gm));
6050+ GGML_ASSERT (ggml_are_same_shape (x, gv));
6051+ GGML_ASSERT (ggml_nelements (p) == 7 );
6052+
6053+ vk_pipeline pipeline = ggml_vk_op_get_pipeline (ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
6054+ GGML_ASSERT (pipeline != nullptr );
6055+
6056+ if (dryrun) {
6057+ ggml_pipeline_request_descriptor_sets (ctx->device , pipeline, 1 );
6058+ return ;
6059+ }
6060+
6061+ ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer ->context ;
6062+ ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer ->context ;
6063+ ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer ->context ;
6064+ ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer ->context ;
6065+ ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer ->context ;
6066+
6067+ ggml_vk_sync_buffers (subctx);
6068+
6069+ vk_buffer d_X = nullptr , d_G = nullptr , d_GM = nullptr , d_GV = nullptr , d_P = nullptr ;
6070+ size_t x_offset = 0 , g_offset = 0 , gm_offset = 0 , gv_offset = 0 , p_offset = 0 ;
6071+ bool X_uma = false , G_uma = false , GM_uma = false , GV_uma = false , P_uma = false ;
6072+
6073+ if (ctx->device ->uma ) {
6074+ ggml_vk_host_get (ctx->device , x->data , d_X, x_offset);
6075+ ggml_vk_host_get (ctx->device , g->data , d_G, g_offset);
6076+ ggml_vk_host_get (ctx->device , gm->data , d_GM, gm_offset);
6077+ ggml_vk_host_get (ctx->device , gv->data , d_GV, gv_offset);
6078+ ggml_vk_host_get (ctx->device , p->data , d_P, p_offset);
6079+
6080+ X_uma = d_X != nullptr ;
6081+ G_uma = d_G != nullptr ;
6082+ GM_uma = d_GM != nullptr ;
6083+ GV_uma = d_GV != nullptr ;
6084+ P_uma = d_P != nullptr ;
6085+ }
6086+
6087+ if (!X_uma) {
6088+ d_X = x_buf_ctx->dev_buffer ;
6089+ x_offset = vk_tensor_offset (x) + x->view_offs ;
6090+ }
6091+ if (!G_uma) {
6092+ d_G = g_buf_ctx->dev_buffer ;
6093+ g_offset = vk_tensor_offset (g) + g->view_offs ;
6094+ }
6095+ if (!GM_uma) {
6096+ d_GM = gm_buf_ctx->dev_buffer ;
6097+ gm_offset = vk_tensor_offset (gm) + gm->view_offs ;
6098+ }
6099+ if (!GV_uma) {
6100+ d_GV = gv_buf_ctx->dev_buffer ;
6101+ gv_offset = vk_tensor_offset (gv) + gv->view_offs ;
6102+ }
6103+ if (!P_uma) {
6104+ d_P = p_buf_ctx->dev_buffer ;
6105+ p_offset = vk_tensor_offset (p) + p->view_offs ;
6106+ }
6107+
6108+ const uint64_t x_size = ggml_nbytes (x);
6109+ const uint64_t g_size = ggml_nbytes (g);
6110+ const uint64_t gm_size = ggml_nbytes (gm);
6111+ const uint64_t gv_size = ggml_nbytes (gv);
6112+ const uint64_t p_size = ggml_nbytes (p);
6113+
6114+ std::array<uint32_t , 3 > elements = { (uint32_t )ggml_nelements (x), 1 , 1 };
6115+
6116+ ggml_vk_dispatch_pipeline (ctx, subctx, pipeline, {
6117+ vk_subbuffer{ d_X, x_offset, x_size },
6118+ vk_subbuffer{ d_G, g_offset, g_size },
6119+ vk_subbuffer{ d_GM, gm_offset, gm_size },
6120+ vk_subbuffer{ d_GV, gv_offset, gv_size },
6121+ vk_subbuffer{ d_P, p_offset, p_size },
6122+ }, sizeof (vk_op_push_constants), &pc, elements);
6123+ }
6124+
6125+ static void ggml_vk_opt_step_adamw (ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false ) {
6126+ const size_t n = ggml_nelements (dst->src [0 ]);
6127+
6128+ ggml_vk_op_f32_opt_step_adamw (
6129+ ctx, subctx, dst,
6130+ { (uint32_t )n, 0 , 0 .0f , 0 .0f },
6131+ dryrun
6132+ );
6133+ }
6134+
60226135static void ggml_vk_concat (ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false ) {
60236136 int * op_params = (int *)dst->op_params ;
60246137
@@ -7191,6 +7304,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
71917304 case GGML_OP_RWKV_WKV6:
71927305 case GGML_OP_LEAKY_RELU:
71937306 case GGML_OP_FLASH_ATTN_EXT:
7307+ case GGML_OP_OPT_STEP_ADAMW:
71947308 break ;
71957309 default :
71967310 std::cerr << " ggml_vulkan: Error: Missing op: " << ggml_op_name (node->op ) << std::endl;
@@ -7413,6 +7527,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
74137527 case GGML_OP_RWKV_WKV6:
74147528 ggml_vk_rwkv_wkv6 (ctx, compute_ctx, node, dryrun);
74157529
7530+ break ;
7531+
7532+ case GGML_OP_OPT_STEP_ADAMW:
7533+ ggml_vk_opt_step_adamw (ctx, compute_ctx, node, dryrun);
7534+
74167535 break ;
74177536 default :
74187537 return false ;
@@ -7500,6 +7619,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
75007619 case GGML_OP_RWKV_WKV6:
75017620 case GGML_OP_LEAKY_RELU:
75027621 case GGML_OP_REPEAT:
7622+ case GGML_OP_OPT_STEP_ADAMW:
75037623 buf = tensor->buffer ;
75047624
75057625 break ;
@@ -8433,6 +8553,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
84338553 case GGML_OP_POOL_2D:
84348554 case GGML_OP_RWKV_WKV6:
84358555 case GGML_OP_LEAKY_RELU:
8556+ case GGML_OP_OPT_STEP_ADAMW:
84368557 return true ;
84378558 default :
84388559 return false ;
@@ -9048,6 +9169,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
90489169 } else if (tensor->op == GGML_OP_RWKV_WKV6) {
90499170 tensor_clone = ggml_rwkv_wkv6 (ggml_ctx, tensor->src [0 ], tensor->src [1 ], tensor->src [2 ], tensor->src [3 ],
90509171 tensor->src [4 ], tensor->src [5 ]);
9172+ } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
9173+ tensor_clone = ggml_opt_step_adamw (ggml_ctx, tensor->src [0 ], tensor->src [1 ], tensor->src [2 ],
9174+ tensor->src [3 ], tensor->src [4 ]);
90519175 }
90529176 else {
90539177 std::cerr << " Missing vk_check_results OP: " << ggml_op_name (tensor->op ) << std::endl;
0 commit comments