
Commit 750e1a4

vulkan: implement GGML_OP_OPT_STEP_ADAMW
1 parent 7a6cb87

3 files changed (+168, −0 lines)

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 124 additions & 0 deletions
@@ -262,6 +262,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
+    vk_pipeline pipeline_opt_step_adamw_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
@@ -2227,6 +2228,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
     for (auto &c : compiles) {
         c.wait();
     }
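
The new pipeline is registered with 5 descriptor bindings and a {512, 1, 1} workgroup size, matching the five buffer bindings and local_size_x = 512 of opt_step_adamw.comp below, and it reuses the generic vk_op_push_constants block. A minimal sketch of the assumed layout of that block follows, inferred from the shader's p.KX bounds check and the { (uint32_t)n, 0, 0.0f, 0.0f } initializer passed by ggml_vk_opt_step_adamw; field names other than KX are assumptions, not taken from this diff.

// Hedged sketch: assumed layout of vk_op_push_constants (defined elsewhere in ggml-vulkan.cpp).
struct vk_op_push_constants {
    uint32_t KX;     // element count; the shader returns early when i >= p.KX
    uint32_t KY;     // unused by this op, set to 0
    float    param1; // unused by this op, set to 0.0f
    float    param2; // unused by this op, set to 0.0f
};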
@@ -5411,6 +5414,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rwkv_wkv6_f32;
         }
         return nullptr;
+    case GGML_OP_OPT_STEP_ADAMW:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_opt_step_adamw_f32;
+        }
+        return nullptr;
     case GGML_OP_LEAKY_RELU:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_leaky_relu_f32;
@@ -6019,6 +6027,111 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
     );
 }
 
+static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) {
+    const ggml_tensor * x = dst->src[0];
+    const ggml_tensor * g = dst->src[1];
+    const ggml_tensor * gm = dst->src[2];
+    const ggml_tensor * gv = dst->src[3];
+    const ggml_tensor * p = dst->src[4];
+
+    GGML_ASSERT(x->type == GGML_TYPE_F32);
+    GGML_ASSERT(g->type == GGML_TYPE_F32);
+    GGML_ASSERT(gm->type == GGML_TYPE_F32);
+    GGML_ASSERT(gv->type == GGML_TYPE_F32);
+    GGML_ASSERT(p->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->buffer != nullptr);
+    GGML_ASSERT(ggml_is_contiguous(x));
+    GGML_ASSERT(ggml_is_contiguous(g));
+    GGML_ASSERT(ggml_is_contiguous(gm));
+    GGML_ASSERT(ggml_is_contiguous(gv));
+    GGML_ASSERT(ggml_is_contiguous(p));
+    GGML_ASSERT(ggml_are_same_shape(x, g));
+    GGML_ASSERT(ggml_are_same_shape(x, gm));
+    GGML_ASSERT(ggml_are_same_shape(x, gv));
+    GGML_ASSERT(ggml_nelements(p) == 7);
+
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW);
+    GGML_ASSERT(pipeline != nullptr);
+
+    if (dryrun) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        return;
+    }
+
+    ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context;
+    ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context;
+    ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context;
+    ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context;
+    ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context;
+
+    ggml_vk_sync_buffers(subctx);
+
+    vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr;
+    size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0;
+    bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, x->data, d_X, x_offset);
+        ggml_vk_host_get(ctx->device, g->data, d_G, g_offset);
+        ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset);
+        ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset);
+        ggml_vk_host_get(ctx->device, p->data, d_P, p_offset);
+
+        X_uma = d_X != nullptr;
+        G_uma = d_G != nullptr;
+        GM_uma = d_GM != nullptr;
+        GV_uma = d_GV != nullptr;
+        P_uma = d_P != nullptr;
+    }
+
+    if (!X_uma) {
+        d_X = x_buf_ctx->dev_buffer;
+        x_offset = vk_tensor_offset(x) + x->view_offs;
+    }
+    if (!G_uma) {
+        d_G = g_buf_ctx->dev_buffer;
+        g_offset = vk_tensor_offset(g) + g->view_offs;
+    }
+    if (!GM_uma) {
+        d_GM = gm_buf_ctx->dev_buffer;
+        gm_offset = vk_tensor_offset(gm) + gm->view_offs;
+    }
+    if (!GV_uma) {
+        d_GV = gv_buf_ctx->dev_buffer;
+        gv_offset = vk_tensor_offset(gv) + gv->view_offs;
+    }
+    if (!P_uma) {
+        d_P = p_buf_ctx->dev_buffer;
+        p_offset = vk_tensor_offset(p) + p->view_offs;
+    }
+
+    const uint64_t x_size = ggml_nbytes(x);
+    const uint64_t g_size = ggml_nbytes(g);
+    const uint64_t gm_size = ggml_nbytes(gm);
+    const uint64_t gv_size = ggml_nbytes(gv);
+    const uint64_t p_size = ggml_nbytes(p);
+
+    std::array<uint32_t, 3> elements = { (uint32_t)ggml_nelements(x), 1, 1 };
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+        vk_subbuffer{ d_X, x_offset, x_size },
+        vk_subbuffer{ d_G, g_offset, g_size },
+        vk_subbuffer{ d_GM, gm_offset, gm_size },
+        vk_subbuffer{ d_GV, gv_offset, gv_size },
+        vk_subbuffer{ d_P, p_offset, p_size },
+    }, sizeof(vk_op_push_constants), &pc, elements);
+}
+
+static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t n = ggml_nelements(dst->src[0]);
+
+    ggml_vk_op_f32_opt_step_adamw(
+        ctx, subctx, dst,
+        { (uint32_t)n, 0, 0.0f, 0.0f },
+        dryrun
+    );
+}
+
 static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     int * op_params = (int *)dst->op_params;
 
@@ -7191,6 +7304,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_FLASH_ATTN_EXT:
+    case GGML_OP_OPT_STEP_ADAMW:
         break;
     default:
         std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -7413,6 +7527,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_RWKV_WKV6:
         ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
 
+        break;
+
+    case GGML_OP_OPT_STEP_ADAMW:
+        ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
+
         break;
     default:
         return false;
@@ -7500,6 +7619,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
     case GGML_OP_REPEAT:
+    case GGML_OP_OPT_STEP_ADAMW:
         buf = tensor->buffer;
 
         break;
@@ -8433,6 +8553,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
     case GGML_OP_POOL_2D:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_LEAKY_RELU:
+    case GGML_OP_OPT_STEP_ADAMW:
         return true;
     default:
         return false;
@@ -9048,6 +9169,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_RWKV_WKV6) {
         tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
                                       tensor->src[4], tensor->src[5]);
+    } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
+        tensor_clone = ggml_opt_step_adamw(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2],
+                                           tensor->src[3], tensor->src[4]);
     }
     else {
         std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
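
For context, the check above rebuilds the op on the reference path with ggml_opt_step_adamw(ctx, x, grad, m, v, params). A minimal usage sketch along the same lines is shown below; the helper name is hypothetical, and the 7-element f32 parameter tensor layout {alpha, beta1, beta2, eps, wd, beta1h, beta2h} is taken from the asserts and the shader in this commit.

// Minimal sketch, not part of this commit: build_adamw_step is a hypothetical helper.
#include "ggml.h"

static struct ggml_tensor * build_adamw_step(struct ggml_context * ctx,
                                             struct ggml_tensor * weights,        // x, f32
                                             struct ggml_tensor * grads,          // same shape as weights
                                             struct ggml_tensor * m,              // first moment, same shape
                                             struct ggml_tensor * v,              // second moment, same shape
                                             struct ggml_tensor * adamw_params) { // f32[7]
    // All five tensors must be contiguous f32, as enforced by ggml_vk_op_f32_opt_step_adamw above.
    return ggml_opt_step_adamw(ctx, weights, grads, m, v, adamw_params);
}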
ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+#version 450
+
+#include "generic_head.comp"
+#include "types.comp"
+
+#extension GL_EXT_control_flow_attributes : enable
+
+layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
+
+layout (binding = 0) buffer X {A_TYPE x[];};
+layout (binding = 1) readonly buffer G {A_TYPE grad[];};
+layout (binding = 2) buffer GM {A_TYPE gradm[];};
+layout (binding = 3) buffer GV {A_TYPE gradv[];};
+layout (binding = 4) readonly buffer P {float params[7];};
+
+void main() {
+    const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
+
+    if (i >= p.KX) {
+        return;
+    }
+
+    const float alpha = params[0];
+    const float beta1 = params[1];
+    const float beta2 = params[2];
+    const float eps = params[3];
+    const float wd = params[4];
+    const float beta1h = params[5];
+    const float beta2h = params[6];
+
+    const float gi = grad[i];
+    const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1);
+    const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2);
+
+    gradm[i] = gmi;
+    gradv[i] = gvi;
+
+    const float mh = gmi*beta1h;
+    const float vh = sqrt(gvi*beta2h) + eps;
+
+    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+}
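
Written out, the per-element update the shader performs is the following, where beta1h and beta2h are assumed to be bias-correction factors precomputed on the host (their computation is not part of this diff):

\begin{aligned}
m_i &\leftarrow \beta_1 m_i + (1-\beta_1)\, g_i \\
v_i &\leftarrow \beta_2 v_i + (1-\beta_2)\, g_i^2 \\
\hat m &= m_i \cdot \mathrm{beta1h}, \qquad \hat v = \sqrt{v_i \cdot \mathrm{beta2h}} + \epsilon \\
x_i &\leftarrow x_i\,(1 - \alpha\,\mathrm{wd}) - \alpha\, \hat m / \hat v
\end{aligned}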

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 2 additions & 0 deletions
@@ -517,6 +517,8 @@ void process_shaders() {
 
     string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
 
+    string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
+
     for (auto &c : compiles) {
         c.wait();
     }
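
This entry compiles opt_step_adamw.comp to SPIR-V and exposes the blob through the opt_step_adamw_f32_len / opt_step_adamw_f32_data symbols consumed by ggml_vk_create_pipeline earlier in this commit. A hedged sketch of the assumed generated declarations follows; only the identifiers appear in the diff, the exact types are illustrative.

// Illustrative only: assumed shape of the generated SPIR-V declarations.
extern const uint64_t      opt_step_adamw_f32_len;    // size of the SPIR-V blob in bytes
extern const unsigned char opt_step_adamw_f32_data[]; // SPIR-V for opt_step_adamw.comp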
