diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 20c912d0e9b..f71c09b6259 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -514,6 +514,7 @@ extern "C" {
         GGML_OP_SOFT_MAX,
         GGML_OP_SOFT_MAX_BACK,
         GGML_OP_ROPE,
+        GGML_OP_ROPE_COMP,
         GGML_OP_ROPE_BACK,
         GGML_OP_CLAMP,
         GGML_OP_CONV_TRANSPOSE_1D,
@@ -1858,6 +1859,44 @@ extern "C" {
     GGML_API void ggml_rope_yarn_corr_dims(
             int n_dims, int n_ctx_orig, float freq_base, float beta_fast, float beta_slow, float dims[2]);
+
+    enum ggml_rope_ordering {
+        GGML_ROPE_ORDERING_NORMAL,
+        GGML_ROPE_ORDERING_NEOX,
+    };
+
+    // RoPE composable API
+    GGML_API struct ggml_tensor * ggml_rope_comp(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b, // pos must be F32
+            int32_t               n_dims,
+            float                 freq_base,
+            enum ggml_rope_ordering ordering);
+
+    GGML_API struct ggml_tensor * ggml_rope_comp_set_freq_factors(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * node,
+            struct ggml_tensor  * freq_factors);
+
+    // set YaRN parameters
+    GGML_API struct ggml_tensor * ggml_rope_comp_set_yarn(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * node,
+            int                   n_ctx_orig,
+            float                 freq_base,
+            float                 freq_scale,  // == 1.0f / scale_factor
+            float                 ramp_factor, // usually 1.0f
+            float                 attn_factor,
+            float                 beta_fast,
+            float                 beta_slow);
+
+    // set M-RoPE mode
+    GGML_API struct ggml_tensor * ggml_rope_comp_set_multi(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * node,
+            int                   sections[GGML_MROPE_SECTIONS]);
+
     // rotary position embedding backward, i.e compute dx from dy
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_ext_back(
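An aside for readers (not part of the patch): the intended call pattern for the composable API, pieced together from the declarations above and the test wrapper near the end of this diff. The tensor names (Qcur, pos_f32, rope_freqs, n_rot) and all numeric hyperparameters below are illustrative only.

    // base op: NeoX-ordered rotation over the first n_rot channels
    struct ggml_tensor * cur = ggml_rope_comp(ctx, Qcur, pos_f32, n_rot,
            10000.0f /* freq_base */, GGML_ROPE_ORDERING_NEOX);

    // optional: layer YaRN on top; callers that enable the ramp also fold
    // the YaRN mscale into attn_factor (see the test wrapper below)
    cur = ggml_rope_comp_set_yarn(ctx, cur,
            /* n_ctx_orig  */ 4096,
            /* freq_base   */ 10000.0f,
            /* freq_scale  */ 0.25f, // == 1.0f / scale_factor
            /* ramp_factor */ 1.0f,
            /* attn_factor */ 1.0f,
            /* beta_fast   */ 32.0f,
            /* beta_slow   */ 1.0f);

    // optional: per-pair frequency factors (e.g. llama3-style rope scaling)
    cur = ggml_rope_comp_set_freq_factors(ctx, cur, rope_freqs);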
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index a59b5189389..30a138a42b5 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -1863,6 +1863,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
             {
                 ggml_compute_forward_rope(params, tensor);
             } break;
+        case GGML_OP_ROPE_COMP:
+            {
+                ggml_compute_forward_rope_comp(params, tensor);
+            } break;
         case GGML_OP_ROPE_BACK:
             {
                 ggml_compute_forward_rope_back(params, tensor);
             } break;
@@ -2294,6 +2298,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX_BACK:
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_COMP:
         case GGML_OP_ROPE_BACK:
         case GGML_OP_ADD_REL_POS:
             {
@@ -2812,6 +2817,7 @@ struct ggml_cplan ggml_graph_plan(
                 } break;
             case GGML_OP_SOFT_MAX:
             case GGML_OP_ROPE:
+            case GGML_OP_ROPE_COMP:
             case GGML_OP_ROPE_BACK:
                 {
                     cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 3032783971d..54b72c24e65 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -5817,6 +5817,154 @@ void ggml_compute_forward_rope(
     }
 }
 
+// ggml_compute_forward_rope_comp
+
+template <typename T> // float or ggml_fp16_t
+static void ggml_compute_forward_rope_comp_flt(
+        const ggml_compute_params * params,
+        ggml_tensor * dst,
+        const bool forward) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+    const ggml_tensor * src2 = dst->src[2];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    int32_t n_dims, idx_pair, idx_scale, idx_offset;
+    float theta_scale, yarn_high, yarn_low, freq_scale, ramp_factor, attn_factor;
+    int32_t sections[4];
+
+    memcpy(&n_dims,      (int32_t *)dst->op_params +  0, sizeof(int32_t));
+    memcpy(&idx_pair,    (int32_t *)dst->op_params +  1, sizeof(int32_t));
+    memcpy(&idx_scale,   (int32_t *)dst->op_params +  2, sizeof(int32_t));
+    memcpy(&idx_offset,  (int32_t *)dst->op_params +  3, sizeof(int32_t));
+    memcpy(&theta_scale, (int32_t *)dst->op_params +  4, sizeof(float));
+    memcpy(&yarn_high,   (int32_t *)dst->op_params +  5, sizeof(float));
+    memcpy(&yarn_low,    (int32_t *)dst->op_params +  6, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *)dst->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *)dst->op_params +  8, sizeof(float));
+    memcpy(&ramp_factor, (int32_t *)dst->op_params +  9, sizeof(float));
+    memcpy(&sections[0], (int32_t *)dst->op_params + 10, sizeof(int32_t));
+    memcpy(&sections[1], (int32_t *)dst->op_params + 11, sizeof(int32_t));
+    memcpy(&sections[2], (int32_t *)dst->op_params + 12, sizeof(int32_t));
+    memcpy(&sections[3], (int32_t *)dst->op_params + 13, sizeof(int32_t));
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == nb00);
+    GGML_ASSERT(nb0 == sizeof(T));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(dst);
+
+    GGML_ASSERT(n_dims <= ne0);
+    GGML_ASSERT(n_dims % 2 == 0);
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // row index used to determine which thread to use
+    int ir = 0;
+
+    // TODO M-RoPE
+
+    const float * freq_factors = NULL;
+    if (src2 != NULL) {
+        GGML_ASSERT(src2->type == GGML_TYPE_F32);
+        GGML_ASSERT(src2->ne[0] >= n_dims / 2);
+        freq_factors = (const float *) src2->data;
+    }
+
+    // backward process uses inverse rotation by cos and sin.
+    // cos and sin build a rotation matrix, where the inverse is the transpose.
+    // this essentially just switches the sign of sin.
+    const float sin_sign = forward ? 1.0f : -1.0f;
+
+    const float * pos = (const float *) src1->data;
+
+    auto init_cache = [&](float * cache, float p) -> void {
+        for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
+            const float freq_factor = freq_factors ? freq_factors[i0/2] : 1.0f;
+
+            float theta = p * powf(theta_scale, i0/2) / freq_factor;
+            const float theta_extrap = theta;
+            const float theta_interp = freq_scale * theta;
+
+            if (ramp_factor != 0.0f) {
+                const float ramp_mix = rope_yarn_ramp(yarn_low, yarn_high, i0) * ramp_factor;
+                theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+            } else {
+                theta = theta_interp;
+            }
+
+            cache[i0 + 0] = cosf(theta) * attn_factor;
+            cache[i0 + 1] = sinf(theta) * attn_factor * sin_sign;
+        }
+    };
+
+    for (int64_t i3 = 0; i3 < ne3; i3++) { // batch
+        for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len
+
+            float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith;
+            {
+                const float p = pos[i2];
+                init_cache(cache, p);
+            }
+            // TODO M-RoPE
+
+            for (int64_t i1 = idx_offset; i1 < ne1; i1++) { // attn-heads
+                if (ir++ < ir0) continue;
+                if (ir   > ir1) break;
+
+                T * src      = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
+                T * dst_data = (T *)((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1);
+
+                rotate_pairs(n_dims, idx_pair, cache, src, dst_data, idx_scale);
+                // TODO M-RoPE
+
+                // fill the remaining channels with data from src tensor
+                for (int64_t i0 = n_dims; i0 < ne0; i0 += 2) {
+                    const T * const src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+                    T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+                    dst_data[0] = src[0];
+                    dst_data[1] = src[1];
+                }
+            } // attn-heads
+        }
+    }
+}
+
+void ggml_compute_forward_rope_comp(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F16:
+            {
+                ggml_compute_forward_rope_comp_flt<ggml_fp16_t>(params, dst, true);
+            } break;
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_rope_comp_flt<float>(params, dst, true);
+            } break;
+        default:
+            {
+                GGML_ABORT("fatal error");
+            }
+    }
+}
+
 // ggml_compute_forward_rope_back
 
 void ggml_compute_forward_rope_back(
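For orientation (an annotation, not patch content), this is the op_params layout that the CPU path above and the Metal path below both read back:

    // GGML_OP_ROPE_COMP op_params slots (14 of the 15 int32 words used):
    //   [ 0] int32 n_dims       written by ggml_rope_comp
    //   [ 1] int32 idx_pair     1 (NORMAL) or n_dims/2 (NEOX)
    //   [ 2] int32 idx_scale    1 (NORMAL) or 2 (NEOX)
    //   [ 3] int32 idx_offset   reserved for 2D-RoPE, currently always 0
    //   [ 4] float theta_scale  freq_base^(-2/n_dims)
    //   [ 5] float yarn_high    written by ggml_rope_comp_set_yarn
    //   [ 6] float yarn_low     written by ggml_rope_comp_set_yarn
    //   [ 7] float freq_scale   1.0f until set_yarn overrides it
    //   [ 8] float attn_factor  1.0f until set_yarn overrides it
    //   [ 9] float ramp_factor  0.0f (ramp disabled) until set_yarn overrides it
    //   [10..13] int32 sections[4]  reserved for M-RoPE (set_multi, still TODO)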
diff --git a/ggml/src/ggml-cpu/ops.h b/ggml/src/ggml-cpu/ops.h
index 0fdfee79766..1926311e174 100644
--- a/ggml/src/ggml-cpu/ops.h
+++ b/ggml/src/ggml-cpu/ops.h
@@ -61,6 +61,7 @@ void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_soft_max(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_soft_max_ext_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rope(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_rope_comp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_rope_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp
index 95627d38665..3ae8152174a 100644
--- a/ggml/src/ggml-metal/ggml-metal-common.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-common.cpp
@@ -261,6 +261,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_COMP:
diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
+ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope_comp(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_ROPE_COMP);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_rope_comp_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
+    if (!res.pipeline) {
+        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+    }
+
+    return res;
+}
+
 ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col(ggml_metal_library_t lib, const ggml_tensor * op) {
     assert(op->op == GGML_OP_IM2COL);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 0a8b9211a76..79fb530936b 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -136,6 +136,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm         (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm               (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope               (ggml_metal_library_t lib, const struct ggml_tensor * op);
+struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rope_comp          (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_im2col             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_1d  (ggml_metal_library_t lib, const struct ggml_tensor * op);
 struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_conv_transpose_2d  (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index f24270bb1c5..54c2b1395a2 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -1029,6 +1029,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_tensor * op) {
         case GGML_OP_RMS_NORM:
             return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
         case GGML_OP_ROPE:
+        case GGML_OP_ROPE_COMP:
             return true;
         case GGML_OP_IM2COL:
             return ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32 && (op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_F32);
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index 8944b07e907..fc1524ace3b 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -260,6 +260,40 @@ typedef struct {
     bool     src2;
 } ggml_metal_kargs_rope;
 
+typedef struct {
+    int32_t  ne00;
+    int32_t  ne01;
+    int32_t  ne02;
+    int32_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne0;
+    int32_t  ne1;
+    int32_t  ne2;
+    int32_t  ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t  n_dims;
+    int32_t  idx_pair;
+    int32_t  idx_scale;
+    int32_t  idx_offset;
+    float    theta_scale;
+    float    yarn_high;
+    float    yarn_low;
+    float    freq_scale;
+    float    ramp_factor;
+    float    attn_factor;
+    int32_t  sect_0;
+    int32_t  sect_1;
+    int32_t  sect_2;
+    int32_t  sect_3;
+    bool     src2;
+} ggml_metal_kargs_rope_comp;
+
 typedef struct {
     int32_t  ne11;
     int32_t  ne_12_2; // assume K and V are same shape
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index e99c1763f63..50059df1ea1 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -370,6 +370,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_rope(ctx, idx);
             } break;
+        case GGML_OP_ROPE_COMP:
+            {
+                n_fuse = ggml_metal_op_rope_comp(ctx, idx);
+            } break;
         case GGML_OP_IM2COL:
             {
                 n_fuse = ggml_metal_op_im2col(ctx, idx);
@@ -3216,6 +3220,96 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_rope_comp(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    // make sure we have one or more position ids (ne10) per token (ne02)
+    GGML_ASSERT(ne10 % ne02 == 0);
+    GGML_ASSERT(ne10 >= ne02);
+
+    const int nth = std::min(1024, ne00);
+
+    int32_t n_dims, idx_pair, idx_scale, idx_offset;
+    float theta_scale, yarn_high, yarn_low, freq_scale, ramp_factor, attn_factor;
+    int32_t sections[4];
+
+    memcpy(&n_dims,      (int32_t *)op->op_params +  0, sizeof(int32_t));
+    memcpy(&idx_pair,    (int32_t *)op->op_params +  1, sizeof(int32_t));
+    memcpy(&idx_scale,   (int32_t *)op->op_params +  2, sizeof(int32_t));
+    memcpy(&idx_offset,  (int32_t *)op->op_params +  3, sizeof(int32_t));
+    memcpy(&theta_scale, (int32_t *)op->op_params +  4, sizeof(float));
+    memcpy(&yarn_high,   (int32_t *)op->op_params +  5, sizeof(float));
+    memcpy(&yarn_low,    (int32_t *)op->op_params +  6, sizeof(float));
+    memcpy(&freq_scale,  (int32_t *)op->op_params +  7, sizeof(float));
+    memcpy(&attn_factor, (int32_t *)op->op_params +  8, sizeof(float));
+    memcpy(&ramp_factor, (int32_t *)op->op_params +  9, sizeof(float));
+    memcpy(&sections[0], (int32_t *)op->op_params + 10, sizeof(int32_t));
+    memcpy(&sections[1], (int32_t *)op->op_params + 11, sizeof(int32_t));
+    memcpy(&sections[2], (int32_t *)op->op_params + 12, sizeof(int32_t));
+    memcpy(&sections[3], (int32_t *)op->op_params + 13, sizeof(int32_t));
+
+    ggml_metal_kargs_rope_comp args = {
+        /*.ne00        =*/ ne00,
+        /*.ne01        =*/ ne01,
+        /*.ne02        =*/ ne02,
+        /*.ne03        =*/ ne03,
+        /*.nb00        =*/ nb00,
+        /*.nb01        =*/ nb01,
+        /*.nb02        =*/ nb02,
+        /*.nb03        =*/ nb03,
+        /*.ne0         =*/ ne0,
+        /*.ne1         =*/ ne1,
+        /*.ne2         =*/ ne2,
+        /*.ne3         =*/ ne3,
+        /*.nb0         =*/ nb0,
+        /*.nb1         =*/ nb1,
+        /*.nb2         =*/ nb2,
+        /*.nb3         =*/ nb3,
+        /*.n_dims      =*/ n_dims,
+        /*.idx_pair    =*/ idx_pair,
+        /*.idx_scale   =*/ idx_scale,
+        /*.idx_offset  =*/ idx_offset,
+        /*.theta_scale =*/ theta_scale,
+        /*.yarn_high   =*/ yarn_high,
+        /*.yarn_low    =*/ yarn_low,
+        /*.freq_scale  =*/ freq_scale,
+        /*.ramp_factor =*/ ramp_factor,
+        /*.attn_factor =*/ attn_factor,
+        /*.sect_0      =*/ sections[0],
+        /*.sect_1      =*/ sections[1],
+        /*.sect_2      =*/ sections[2],
+        /*.sect_3      =*/ sections[3],
+        /*.src2        =*/ op->src[2] != nullptr,
+    };
+
+    auto pipeline = ggml_metal_library_get_pipeline_rope_comp(lib, op);
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
+    if (op->src[2]) {
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[2]), 3);
+    } else {
+        ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
+    }
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 4);
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
+
+    return 1;
+}
+
 int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
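Two encoding details worth calling out (a reading of the code above, not patch text): when the op has no freq_factors tensor, src0 is bound again at buffer index 3 as a placeholder so the argument table stays fixed, and args.src2 tells the kernel whether that slot is meaningful; the dispatch assigns one threadgroup per output row:

    // threadgroup grid  : (ne01 heads) x (ne02 tokens) x (ne03 batches)
    // threads per group : nth = min(1024, ne00); thread t covers channel
    //                     pairs i0 = 2*t, 2*(t + tptg.x), ... while i0 < ne0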
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 902b5445232..033364f8d99 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -71,6 +71,7 @@ int ggml_metal_op_l2_norm (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_group_norm        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_norm              (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_rope              (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_rope_comp         (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_im2col            (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_2d           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_conv_transpose_1d (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 51bcbae309f..f84eed48428 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -4099,6 +4099,61 @@ static void rope_yarn_corr_dims(
     dims[1] = min(n_dims - 1.0f, ceil(rope_yarn_corr_factor(n_dims, n_ctx_orig, beta_slow, freq_base)));
 }
 
+template<typename T>
+kernel void kernel_rope_comp(
+        constant ggml_metal_kargs_rope_comp & args,
+        device const char * src0,
+        device const char * src1,
+        device const char * src2,
+        device       char * dst,
+        ushort  tiitg[[thread_index_in_threadgroup]],
+        ushort3 tptg [[threads_per_threadgroup]],
+        uint3   tgpig[[threadgroup_position_in_grid]]) {
+    const int i3 = tgpig[2];
+    const int i2 = tgpig[1];
+    const int i1 = tgpig[0];
+
+    device const float * pos = (device const float *) src1;
+    const float p = pos[i2];
+
+    for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+        if (i0 < args.n_dims) {
+            const float freq_factor = args.src2 ? ((device const float *) src2)[i0/2] : 1.0f;
+
+            float theta = p * pow(args.theta_scale, i0/2) / freq_factor;
+            const float theta_extrap = theta;
+            const float theta_interp = args.freq_scale * theta;
+
+            if (args.ramp_factor != 0.0f) {
+                const float ramp_mix = rope_yarn_ramp(args.yarn_low, args.yarn_high, i0) * args.ramp_factor;
+                theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
+            } else {
+                theta = theta_interp;
+            }
+
+            const float cos_theta = cos(theta) * args.attn_factor;
+            const float sin_theta = sin(theta) * args.attn_factor;
+
+            const int ic = i0 / args.idx_scale;
+
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0);
+
+            const float x0 = src[0];
+            const float x1 = src[args.idx_pair];
+
+            dst_data[0]             = x0*cos_theta - x1*sin_theta;
+            dst_data[args.idx_pair] = x0*sin_theta + x1*cos_theta;
+        } else {
+            device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+            device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0);
+
+            dst_data[0] = src[0];
+            dst_data[1] = src[1];
+        }
+    }
+}
+
 template<typename T>
 kernel void kernel_rope_norm(
         constant ggml_metal_kargs_rope & args,
@@ -4359,6 +4414,7 @@ typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
 typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
 typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
 typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
+typedef decltype(kernel_rope_comp<float>) kernel_rope_comp_t;
 
 template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
 template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
@@ -4372,6 +4428,9 @@ template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
 template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
 template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
+template [[host_name("kernel_rope_comp_f32")]] kernel kernel_rope_comp_t kernel_rope_comp<float>;
+template [[host_name("kernel_rope_comp_f16")]] kernel kernel_rope_comp_t kernel_rope_comp<half>;
+
 typedef void (im2col_t)(
         constant ggml_metal_kargs_im2col & args,
         device const float * x,
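How the single kernel covers both channel orderings (an annotation, not patch text): idx_pair and idx_scale, set by ggml_rope_comp in ggml.c below, reduce the pair selection to plain index arithmetic:

    // NORMAL: idx_pair = 1,        idx_scale = 1 -> pair k is (2k, 2k+1)
    // NEOX  : idx_pair = n_dims/2, idx_scale = 2 -> pair k is (k, k + n_dims/2)
    //
    // the loop walks i0 = 0, 2, 4, ...; the first element sits at
    // ic = i0/idx_scale and its partner at ic + idx_pair:
    //   dst[ic]            = x0*cos(theta_k) - x1*sin(theta_k)
    //   dst[ic + idx_pair] = x0*sin(theta_k) + x1*cos(theta_k)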
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index eb3ae72eaac..9ff988e07a5 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -991,6 +991,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SOFT_MAX",
     "SOFT_MAX_BACK",
     "ROPE",
+    "ROPE_COMP",
    "ROPE_BACK",
     "CLAMP",
     "CONV_TRANSPOSE_1D",
@@ -1045,7 +1046,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "GLU",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1154,7 +1155,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "glu(x)",
 };
 
-static_assert(GGML_OP_COUNT == 95, "GGML_OP_COUNT != 95");
+static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -4265,6 +4266,107 @@ void ggml_rope_yarn_corr_dims(
     dims[1] = MIN(n_dims - 1, end);
 }
 
+// ggml_rope_comp
+
+GGML_API struct ggml_tensor * ggml_rope_comp(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int32_t               n_dims,
+        float                 freq_base,
+        enum ggml_rope_ordering ordering) {
+    GGML_ASSERT(ggml_is_vector(b));
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(b->ne[0] >= a->ne[2]); // also allow M-RoPE
+    GGML_ASSERT(b->ne[0] %  a->ne[2] == 0);
+
+    int32_t idx_pair  = 1;
+    int32_t idx_scale = 1;
+    if (ordering == GGML_ROPE_ORDERING_NEOX) {
+        idx_pair  = n_dims / 2;
+        idx_scale = 2;
+    }
+
+    // note: theta_i = theta_base * theta_scale^i, where i indexes the channel pairs
+    // and theta_base == the position (0, 1, 2, ..., n_tokens - 1)
+    const float theta_scale = powf(freq_base, -2.0f / (float)n_dims);
+
+    int32_t i_zero = 0;
+    float   f_zero = 0.0f;
+    float   f_one  = 1.0f;
+
+    struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
+    int32_t params[15];
+    memset(params, 0, sizeof(params));
+    memcpy(params +  0, &n_dims,      sizeof(int32_t)); // n_dims
+    memcpy(params +  1, &idx_pair,    sizeof(int32_t)); // idx_pair
+    memcpy(params +  2, &idx_scale,   sizeof(int32_t)); // idx_scale
+    memcpy(params +  3, &i_zero,      sizeof(int32_t)); // idx_offset for 2D-RoPE
+    memcpy(params +  4, &theta_scale, sizeof(float));   // theta_scale
+    memcpy(params +  5, &f_zero,      sizeof(float));   // yarn_high
+    memcpy(params +  6, &f_zero,      sizeof(float));   // yarn_low
+    memcpy(params +  7, &f_one,       sizeof(float));   // freq_scale
+    memcpy(params +  8, &f_one,       sizeof(float));   // attn_factor
+    memcpy(params +  9, &f_zero,      sizeof(float));   // ramp_factor
+    memcpy(params + 10, &i_zero,      sizeof(int32_t)); // sections[0]
+    memcpy(params + 11, &i_zero,      sizeof(int32_t)); // sections[1]
+    memcpy(params + 12, &i_zero,      sizeof(int32_t)); // sections[2]
+    memcpy(params + 13, &i_zero,      sizeof(int32_t)); // sections[3]
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_ROPE_COMP;
+    result->src[0] = a;
+    result->src[1] = b;
+    result->src[2] = NULL;
+
+    return result;
+}
+
+struct ggml_tensor * ggml_rope_comp_set_freq_factors(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * node,
+        struct ggml_tensor  * freq_factors) {
+    GGML_UNUSED(ctx);
+    GGML_ASSERT(node->op == GGML_OP_ROPE_COMP);
+    GGML_ASSERT(freq_factors->type == GGML_TYPE_F32);
+
+    const int32_t n_dims = *((int32_t *) node->op_params + 0);
+    GGML_ASSERT(freq_factors->ne[0] >= n_dims / 2);
+
+    node->src[2] = freq_factors;
+
+    return node;
+}
+
+struct ggml_tensor * ggml_rope_comp_set_yarn(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * node,
+        int                   n_ctx_orig,
+        float                 freq_base,
+        float                 freq_scale,
+        float                 ramp_factor,
+        float                 attn_factor,
+        float                 beta_fast,
+        float                 beta_slow) {
+    GGML_UNUSED(ctx);
+    GGML_ASSERT(node->op == GGML_OP_ROPE_COMP);
+
+    const int32_t n_dims = *((int32_t *) node->op_params + 0);
+
+    const float start = floorf(ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_fast, freq_base));
+    const float end   = ceilf (ggml_rope_yarn_corr_dim(n_dims, n_ctx_orig, beta_slow, freq_base));
+    const float yarn_low  = MAX(0, start);
+    const float yarn_high = MIN(n_dims - 1, end);
+
+    memcpy((float *) node->op_params + 5, &yarn_high,   sizeof(float));
+    memcpy((float *) node->op_params + 6, &yarn_low,    sizeof(float));
+    memcpy((float *) node->op_params + 7, &freq_scale,  sizeof(float));
+    memcpy((float *) node->op_params + 8, &attn_factor, sizeof(float));
+    memcpy((float *) node->op_params + 9, &ramp_factor, sizeof(float));
+
+    return node;
+}
+
 // ggml_rope_back
 
 struct ggml_tensor * ggml_rope_ext_back(
@@ -6848,6 +6950,7 @@ void ggml_build_backward_expand(
         case GGML_OP_GET_ROWS:      // row indices not differentiable
         case GGML_OP_GET_ROWS_BACK: // same as for GET_ROWS
         case GGML_OP_ROPE:          // positions not differentiable
+        case GGML_OP_ROPE_COMP:     // same as for ROPE
             ignore_src[1] = true;
             break;
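A worked example of what ggml_rope_comp_set_yarn stores (numbers assumed for illustration, not taken from the patch):

    // n_dims = 128, n_ctx_orig = 4096, freq_base = 10000,
    // beta_fast = 32, beta_slow = 1
    //
    // ggml_rope_yarn_corr_dim(d, L, beta, base) = d * ln(L / (2*pi*beta)) / (2*ln(base))
    //   beta_fast = 32 -> corr_dim ~= 20.94 -> yarn_low  = floorf(20.94) = 20
    //   beta_slow =  1 -> corr_dim ~= 45.03 -> yarn_high = ceilf(45.03)  = 46
    //
    // pairs below ~20 keep the extrapolated angle, pairs above ~46 use the
    // interpolated (freq_scale-multiplied) angle, and rope_yarn_ramp blends
    // the range in between; ramp_factor == 0.0f bypasses the blend entirely.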
diff --git a/tests/test-rope.cpp b/tests/test-rope.cpp
index 801e4cd8270..de604ba4a41 100644
--- a/tests/test-rope.cpp
+++ b/tests/test-rope.cpp
@@ -124,6 +124,360 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph * graph, int n_threads) {
     ggml_graph_compute(graph, &plan);
 }
 
+//
+// test comparing rope and rope_comp
+//
+
+struct test_rope {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    int   n_dims;
+    int   mode;
+    int   n_ctx;   // used to generate positions
+    float fs;      // freq_scale
+    float ef;      // ext_factor
+    float af;      // attn_factor
+    bool  ff;
+    int   v;       // view (1 : non-contiguous a)
+    bool  forward; // unused for now
+    bool  inplace;
+
+    bool use_comp = false;
+
+    std::string vars() {
+        char buf[256];
+        snprintf(buf, sizeof(buf),
+            "type=%d ne=(%lld,%lld,%lld,%lld) n_dims=%d mode=%d fs=%f ef=%f af=%f ff=%d v=%d inplace=%d",
+            type, ne_a[0], ne_a[1], ne_a[2], ne_a[3], n_dims, mode, fs, ef, af, ff ? 1 : 0, v, inplace ? 1 : 0);
+        return std::string(buf);
+    }
+
+    test_rope(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {10, 5, 3, 1},
+            int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f,
+            float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false)
+        : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {}
+
+    ggml_tensor * _ggml_rope_multi(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int   n_dims,
+            int   sections[GGML_MROPE_SECTIONS],
+            int   mode,
+            int   n_ctx_orig,
+            float freq_base,
+            float freq_scale,
+            float ext_factor,
+            float attn_factor,
+            float beta_fast,
+            float beta_slow) {
+        if (use_comp) {
+            return nullptr;
+        } else {
+            return ggml_rope_multi(
+                ctx, a, b, c, n_dims, sections, mode, n_ctx_orig,
+                freq_base, freq_scale, ext_factor, attn_factor,
+                beta_fast, beta_slow);
+        }
+    }
+
+    struct ggml_tensor * _ggml_rope_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * b,
+            struct ggml_tensor  * c,
+            int   n_dims,
+            int   mode,
+            int   n_ctx_orig,
+            float freq_base,
+            float freq_scale,
+            float ext_factor,
+            float attn_factor,
+            float beta_fast,
+            float beta_slow) {
+        if (use_comp) {
+            b = ggml_cast(ctx, b, GGML_TYPE_F32); // pos must be F32
+            bool is_neox = (mode == GGML_ROPE_TYPE_NEOX);
+            ggml_tensor * x = ggml_rope_comp(
+                ctx, a, b, n_dims,
+                freq_base, is_neox ? GGML_ROPE_ORDERING_NEOX : GGML_ROPE_ORDERING_NORMAL);
+            if (ext_factor != 0.0f) {
+                attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+            }
+            x = ggml_rope_comp_set_yarn(ctx, x, n_ctx_orig,
+                freq_base, freq_scale, ext_factor, attn_factor,
+                beta_fast, beta_slow);
+            if (freq) {
+                x = ggml_rope_comp_set_freq_factors(ctx, x, freq);
+            }
+            return x;
+        } else {
+            return ggml_rope_ext(
+                ctx, a, b, c, n_dims, mode, n_ctx_orig,
+                freq_base, freq_scale, ext_factor, attn_factor,
+                beta_fast, beta_slow);
+        }
+    }
+
+    ggml_tensor * a    = nullptr;
+    ggml_tensor * freq = nullptr;
+    ggml_tensor * pos  = nullptr;
+
+    void build_input(ggml_context * ctx) {
+        GGML_ASSERT(a == nullptr);
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 2; ne[1] *= 4; ne[2] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            if (forward) {
+                ggml_set_param(a);
+            }
+            ggml_set_input(a);
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            if (forward) {
+                ggml_set_param(a);
+            }
+            ggml_set_input(a);
+            ggml_set_name(a, "a");
+        }
+
+        const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
+        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+
+        if (is_mrope || is_vision) {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2] * 4);
+        } else {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, ne_a[2]);
+        }
+        ggml_set_input(pos);
+        ggml_set_name(pos, "pos");
+
+        if (ff) {
+            freq = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_dims/2);
+            ggml_set_input(freq);
+            ggml_set_name(freq, "freq");
+        }
+    }
+
+    ggml_tensor * build_graph(ggml_context * ctx) {
+        const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
+        const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
+        ggml_tensor * out = nullptr;
+        if (is_mrope) {
+            if (is_vision) {
+                GGML_ASSERT(n_dims/4 > 0);
+                int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only uses the first two sections for the image (x, y) coordinates
+                if (forward) {
+                    if (inplace) {
+                        //out = _ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = _ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
+                } else {
+                    //out = _ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            } else {
+                GGML_ASSERT(n_dims/3 > 0);
+                int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0};
+                if (forward) {
+                    if (inplace) {
+                        //out = _ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    } else {
+                        out = _ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                    }
+                } else {
+                    //out = _ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            }
+        } else {
+            if (forward) {
+                if (inplace) {
+                    //out = _ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                } else {
+                    out = _ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+                }
+            } else {
+                //out = _ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f);
+            }
+        }
+
+        if (out) {
+            ggml_set_name(out, "out");
+        }
+
+        return out;
+    }
+
+    void init_tensor_uniform(ggml_tensor * tensor, float fmin = -1.0f, float fmax = 1.0f) {
+        const size_t n_elements = ggml_nelements(tensor);
+        switch (tensor->type) {
+            case GGML_TYPE_F32:
+                {
+                    float * data = (float *)tensor->data;
+                    for (size_t i = 0; i < n_elements; i++) {
+                        data[i] = frand()*(fmax - fmin) + fmin;
+                    }
+                } break;
+            case GGML_TYPE_F16:
+                {
+                    ggml_fp16_t * data = (ggml_fp16_t *)tensor->data;
+                    for (size_t i = 0; i < n_elements; i++) {
+                        float v = frand()*(fmax - fmin) + fmin;
+                        data[i] = ggml_fp32_to_fp16(v);
+                    }
+                } break;
+            default:
+                assert(false);
+        }
+    }
+
+    void initialize_tensors(ggml_context * ctx) {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if ((t->flags & GGML_TENSOR_FLAG_INPUT) == 0) {
+                continue;
+            }
+            if (t->type == GGML_TYPE_I32) {
+                // pos
+                const int num_pos_ids = (mode & GGML_ROPE_TYPE_MROPE) ? ne_a[2] * 4 : ne_a[2];
+                std::vector<int> data(num_pos_ids);
+                for (int i = 0; i < num_pos_ids; i++) {
+                    data[i] = rand() % n_ctx;
+                }
+                //printf("init pos tensor %s\n", ggml_get_name(t));
+                memcpy(t->data, data.data(), num_pos_ids * sizeof(int));
+            } else {
+                if (t->ne[0] == n_dims/2) {
+                    // frequency factors in the range [0.9f, 1.1f]
+                    //printf("init freq tensor %s\n", ggml_get_name(t));
+                    init_tensor_uniform(t, 0.9f, 1.1f);
+                } else {
+                    //printf("init param tensor %s\n", ggml_get_name(t));
+                    init_tensor_uniform(t);
+                }
+            }
+        }
+    }
+};
+
+static void test_rope_comp() {
+    ggml_init_params params = {
+        /* .mem_size   = */ 128*1024*1024,
+        /* .mem_buffer = */ NULL,
+        /* .no_alloc   = */ false,
+    };
+
+    std::vector<test_rope *> test_cases;
+
+    bool all = true;
+    bool fw  = true;
+    for (float fs : { 1.0f, 1.4245f }) {
+        for (float ef : { 0.0f, 0.7465f }) {
+            for (float af : { 1.0f, 1.4245f }) {
+                for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                    for (bool ff : {false, true}) { // freq_factors
+                        for (float v : { 0, 1 }) {
+                            test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B
+
+                            if (all) {
+                                test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B
+                                test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B
+                                test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B
+                            }
+
+                            if (all) {
+                                test_cases.emplace_back(new test_rope(type, { 64,  1, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B)
+                                test_cases.emplace_back(new test_rope(type, { 64,  8, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw));
+
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm)
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                                test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2)
+                            }
+
+                            if (all) {
+                                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B)
+                                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl 7B)
+                                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1},  20, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1},  32, GGML_ROPE_TYPE_MROPE,  512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B)
+                                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,imrope (qwen3vl 7B)
+                                test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1},  20, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, {128, 28, 2, 1},  32, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw));
+                                test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1},  80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT)
+                                test_cases.emplace_back(new test_rope(type, {128, 16, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen3vl)
+                            }
+
+                            test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B)
+                        }
+                    }
+
+                    all = false;
+                }
+            }
+        }
+    }
+
+    std::vector<uint8_t> work_buffer;
+
+    size_t n_tested = 0;
+    size_t n_passed = 0;
+
+    for (size_t i = 0; i < test_cases.size(); i++) {
+        test_rope * tc = test_cases[i];
+
+        ggml_context * ctx0 = ggml_init(params);
+        ggml_cgraph * gf = ggml_new_graph(ctx0);
+
+        tc->build_input(ctx0);
+        tc->initialize_tensors(ctx0);
+
+        ggml_tensor * out0 = tc->build_graph(ctx0);
+        tc->use_comp = true;
+        ggml_tensor * out1 = tc->build_graph(ctx0);
+
+        if (out0 == nullptr || out1 == nullptr) {
+            GGML_PRINT("test_rope_comp \x1b[33mSKIPPED\x1b[0m: %s\n", tc->vars().c_str());
+            ggml_free(ctx0);
+            delete tc;
+            continue;
+        }
+
+        n_tested++;
+
+        // calculate nmse between out0 and out1
+        ggml_tensor * diff = ggml_sub(ctx0, out0, out1);
+        ggml_tensor * mse_a_b = ggml_sum(ctx0, ggml_sqr(ctx0, diff));
+        ggml_tensor * mse_a_0 = ggml_sum(ctx0, ggml_sqr(ctx0, out0));
+        ggml_tensor * out = ggml_div(ctx0, mse_a_b, mse_a_0);
+        out = ggml_cast(ctx0, out, GGML_TYPE_F32);
+
+        ggml_build_forward_expand(gf, out);
+        ggml_graph_compute_helper(work_buffer, gf, 4);
+
+        GGML_ASSERT(ggml_nelements(out) == 1);
+        float nmse = ((float *)out->data)[0];
+        const float nmse_threshold = 1e-3f;
+        if (nmse > nmse_threshold) {
+            GGML_PRINT("test_rope_comp \x1b[31mFAILED\x1b[0m: nmse=%f > %f for %s\n", nmse, nmse_threshold, tc->vars().c_str());
+        } else {
+            GGML_PRINT("test_rope_comp OK : nmse=%f <= %f for %s\n", nmse, nmse_threshold, tc->vars().c_str());
+            n_passed++;
+        }
+
+        ggml_free(ctx0);
+        delete tc;
+    }
+
+    // skipped cases (no rope_comp equivalent yet) do not count against the total
+    GGML_ASSERT(n_passed == n_tested);
+}
+
 int main(int /*argc*/, const char ** /*argv*/) {
     struct ggml_init_params params = {
         /* .mem_size   = */ 128*1024*1024,
@@ -259,5 +613,7 @@ int main(int /*argc*/, const char ** /*argv*/) {
 
     ggml_free(ctx0);
 
+    test_rope_comp();
+
     return 0;
 }
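Closing note on the pass criterion (an annotation, not patch content): the graph built above computes the normalized mean-squared error, nmse = sum((out0 - out1)^2) / sum(out0^2), and compares it against 1e-3. A host-side equivalent, with nmse_host being a hypothetical helper rather than anything in the tree, would be:

    // host-side equivalent of the graph-built check (hypothetical helper)
    static float nmse_host(const float * ref, const float * cmp, size_t n) {
        double sse  = 0.0; // sum of squared differences
        double sref = 0.0; // sum of squared reference values
        for (size_t i = 0; i < n; i++) {
            const double d = (double) ref[i] - (double) cmp[i];
            sse  += d*d;
            sref += (double) ref[i] * (double) ref[i];
        }
        return (float) (sse/sref);
    }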