diff --git a/common.hpp b/common.hpp index 33d499fb1..74b218ab7 100644 --- a/common.hpp +++ b/common.hpp @@ -194,10 +194,12 @@ class GEGLU : public UnaryBlock { auto proj = std::dynamic_pointer_cast(blocks["proj"]); x = proj->forward(ctx, x); // [ne3, ne2, ne1, dim_out*2] - auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0); + auto x_vec = ggml_ext_chunk(ctx->ggml_ctx, x, 2, 0, false); x = x_vec[0]; // [ne3, ne2, ne1, dim_out] auto gate = x_vec[1]; // [ne3, ne2, ne1, dim_out] + gate = ggml_cont(ctx->ggml_ctx, gate); + gate = ggml_gelu_inplace(ctx->ggml_ctx, gate); x = ggml_mul(ctx->ggml_ctx, x, gate); // [ne3, ne2, ne1, dim_out] diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 07b9bfbf0..26dff4993 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -732,34 +732,22 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_ext_slice(struct ggml_context* ctx, __STATIC_INLINE__ std::vector ggml_ext_chunk(struct ggml_context* ctx, struct ggml_tensor* x, int num, - int64_t dim) { + int64_t dim, + bool cont = true) { GGML_ASSERT(dim >= 0 && dim < 4); GGML_ASSERT(x->ne[dim] % num == 0); - int perm[4] = {0, 1, 2, 3}; - for (int i = dim; i < 3; ++i) - perm[i] = perm[i + 1]; - perm[3] = dim; - - int inv_perm[4]; - for (int i = 0; i < 4; ++i) - inv_perm[perm[i]] = i; - - if (dim != 3) { - x = ggml_ext_torch_permute(ctx, x, perm[0], perm[1], perm[2], perm[3]); - x = ggml_cont(ctx, x); - } - std::vector chunks; - int64_t chunk_size = x->ne[3] / num; + int64_t chunk_size = x->ne[dim] / num; + int64_t stride = chunk_size * x->nb[dim]; + int64_t chunk_ne[4] = {x->ne[0], x->ne[1], x->ne[2], x->ne[3]}; + chunk_ne[dim] = chunk_size; for (int i = 0; i < num; i++) { auto chunk = ggml_view_4d( ctx, x, - x->ne[0], x->ne[1], x->ne[2], chunk_size, - x->nb[1], x->nb[2], x->nb[3], x->nb[3] * i * chunk_size); - - if (dim != 3) { - chunk = ggml_ext_torch_permute(ctx, chunk, inv_perm[0], inv_perm[1], inv_perm[2], inv_perm[3]); + chunk_ne[0], chunk_ne[1], chunk_ne[2], chunk_ne[3], + x->nb[1], x->nb[2], x->nb[3], stride * i); + if (cont) { chunk = ggml_cont(ctx, chunk); } chunks.push_back(chunk); @@ -772,7 +760,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* // x: [ne3, ne2, ne1, ne0] // return: [ne3, ne2, ne1, ne0/2] - auto x_vec = ggml_ext_chunk(ctx, x, 2, 0); + auto x_vec = ggml_ext_chunk(ctx, x, 2, 0, false); ggml_tensor* gate; if (gate_first) { gate = x_vec[0]; @@ -781,7 +769,7 @@ __STATIC_INLINE__ ggml_tensor* ggml_ext_silu_act(ggml_context* ctx, ggml_tensor* x = x_vec[0]; gate = x_vec[1]; } - + gate = ggml_cont(ctx, gate); gate = ggml_silu_inplace(ctx, gate); x = ggml_mul(ctx, x, gate); // [ne3, ne2, ne1, ne0/2]