Skip to content

Commit 7ee36cc

Browse files
ext.xingsilan1 and DongheJin
authored and committed
fix: cos sin error
1 parent 71add7e commit 7ee36cc

File tree

2 files changed

+22
-23
lines changed

2 files changed

+22
-23
lines changed

xllm/models/vlm/glm4v.h

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,7 @@ class Glm4vVisionEmbeddingsImpl : public torch::nn::Module {
363363
}
364364
torch::Tensor forward(
365365
torch::Tensor x,
366-
std::vector<int64_t> lengths,
366+
std::vector<int> lengths,
367367
torch::Tensor image_shapes,
368368
torch::Tensor h_coords,
369369
torch::Tensor w_coords
@@ -399,12 +399,12 @@ class Glm4vVisionEmbeddingsImpl : public torch::nn::Module {
399399
std::vector<torch::Tensor> target_w_list;
400400
target_h_list.reserve(batch_size);
401401
target_w_list.reserve(batch_size);
402-
402+
LOG(INFO) << " Glm4vVisionEmbeddingsImpl forward batch_size: " << batch_size << "image_shapes " << image_shapes;
403403
for (int64_t i = 0; i < batch_size; ++i) {
404404
const int64_t seq_len = lengths[i];
405405
const auto img_h = image_shapes.index({i, 1}).to(torch::kFloat32);
406406
const auto img_w = image_shapes.index({i, 2}).to(torch::kFloat32);
407-
407+
LOG(INFO) << " Glm4vVisionEmbeddingsImpl forward batch_size idx " << i;
408408
target_h_list.push_back(img_h.repeat({seq_len}));
409409
target_w_list.push_back(img_w.repeat({seq_len}));
410410
}
@@ -417,16 +417,17 @@ class Glm4vVisionEmbeddingsImpl : public torch::nn::Module {
417417

418418
const auto norm_w = ((w_coords_fp32 + 0.5f) / target_w) * 2.0f - 1.0f;
419419
const auto norm_h = ((h_coords_fp32 + 0.5f) / target_h) * 2.0f - 1.0f;
420-
420+
LOG(INFO) << " Glm4vVisionEmbeddingsImpl stack";
421421
auto grid = torch::stack({norm_w, norm_h}, -1)
422422
.unsqueeze(0)
423423
.unsqueeze(2);
424-
424+
LOG(INFO) << " Glm4vVisionEmbeddingsImpl stack after";
425425
namespace F = torch::nn::functional;
426426
auto interpolated_embed = F::grid_sample(
427427
pos_embed_2d,
428428
grid,
429429
F::GridSampleFuncOptions().mode(torch::kBicubic).padding_mode(torch::kBorder).align_corners(false));
430+
LOG(INFO) << " Glm4vVisionEmbeddingsImpl interpolated_embed";
430431
adapted_pos_embed = interpolated_embed
431432
.squeeze(0)
432433
.squeeze(-1)
@@ -656,34 +657,32 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
656657
auto emb = torch::cat({rotary_pos_emb, rotary_pos_emb}, -1);
657658
auto m_cos = emb.cos();
658659
auto m_sin = emb.sin();
660+
LOG(INFO) << " Glm4VisionTransformerImpl" << " numel=" << grid_thw.numel() << " min=" << grid_thw.min().item<float>();
661+
LOG(INFO) << grid_thw;
659662

660663
auto device = grid_thw.device();
661-
auto grid_t = grid_thw.index_select(1, torch::tensor({0}, torch::TensorOptions().dtype(torch::kLong).device(device)));
662-
auto grid_h = grid_thw.index_select(1, torch::tensor({1}, torch::TensorOptions().dtype(torch::kLong).device(device)));
663-
auto grid_w = grid_thw.index_select(1, torch::tensor({2}, torch::TensorOptions().dtype(torch::kLong).device(device)));
664+
auto grid_t = grid_thw.index_select(1, torch::tensor({0}, torch::TensorOptions().dtype(torch::kInt).device(device)));
665+
auto grid_h = grid_thw.index_select(1, torch::tensor({1}, torch::TensorOptions().dtype(torch::kInt).device(device)));
666+
auto grid_w = grid_thw.index_select(1, torch::tensor({2}, torch::TensorOptions().dtype(torch::kInt).device(device)));
664667
auto h_times_w = (grid_h * grid_w).squeeze(1);
665668
auto repeats = grid_t.squeeze(1);
666669
auto repeated = torch::repeat_interleave(h_times_w, repeats, 0);
667670
c10::optional<torch::ScalarType> cumsum_dtype;
668671

672+
LOG(INFO) << " Glm4VisionTransformerImpl repeated " << repeated;
673+
669674
cumsum_dtype = torch::kInt32;
670675
auto cu_seqlens = torch::cumsum(repeated, 0, cumsum_dtype);
671676
namespace F = torch::nn::functional;
672-
cu_seqlens = F::pad(
673-
cu_seqlens, F::PadFuncOptions({1, 0}).mode(torch::kConstant).value(0));
674-
cu_seqlens = torch::diff(cu_seqlens);
675-
torch::Tensor cu_seqlens_slice1 = cu_seqlens.narrow(0, 1, cu_seqlens.size(0) - 1);
676-
torch::Tensor cu_seqlens_slice0 = cu_seqlens.narrow(0, 0, cu_seqlens.size(0) - 1);
677-
torch::Tensor seqlens_tensor = cu_seqlens_slice1 - cu_seqlens_slice0;
678-
std::vector<int64_t> seqlens;
679-
seqlens.assign(
680-
seqlens_tensor.cpu().to(torch::kLong).data_ptr<int64_t>(),
681-
seqlens_tensor.cpu().to(torch::kLong).data_ptr<int64_t>() + seqlens_tensor.numel()
682-
);
683-
LOG(INFO) << " Glm4VisionTransformerImpl forward embedding before ";
677+
cu_seqlens = F::pad(cu_seqlens, F::PadFuncOptions({1, 0}).mode(torch::kConstant).value(0));
678+
cu_seqlens = torch::diff(cu_seqlens).cpu().to(torch::kInt);
679+
std::vector<int> seqlens;
680+
seqlens.assign(cu_seqlens.data_ptr<int>(),cu_seqlens.data_ptr<int>() + cu_seqlens.numel());
681+
682+
LOG(INFO) << " Glm4VisionTransformerImpl forward embedding before cu_seqlens " << cu_seqlens << "seqlens.size()" << seqlens.size();
684683
hidden_states = embeddings_(hidden_states, seqlens, grid_thw, image_type_ids.select(1, 0), image_type_ids.select(1, 1));
685-
ModelInputParams& input_params_new =
686-
const_cast<ModelInputParams&>(input_params);
684+
LOG(INFO) << " Glm4VisionTransformerImpl forward embedding after ";
685+
ModelInputParams& input_params_new = const_cast<ModelInputParams&>(input_params);
687686
torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu();
688687
std::vector<int> cu_seqlens_vec(
689688
cu_seqlens_cpu.data_ptr<int>(), // full seqlen vec

xllm/models/vlm/glm4v_moe.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module {
6060
LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings pixel aft ";
6161
auto image_embeds =
6262
visual_(pixel,
63-
image_input->image_grid_thw,
63+
image_input->image_grid_thw.to(pixel.device()),
6464
input_params);
6565
LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings visual_ end ";
6666
auto inputs_embeds = language_model_->get_input_embeddings(input_ids);

0 commit comments

Comments (0)