
Commit 71add7e

ext.xingsilan1 authored and DongheJin committed
bug fix
1 parent dc7633e commit 71add7e


2 files changed: +20 -16 lines changed

xllm/models/vlm/glm4v.h

Lines changed: 3 additions & 5 deletions

@@ -648,9 +648,9 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
 const ModelInputParams& input_params) {
 LOG(INFO) << " Glm4VisionTransformerImpl forward beging ";
 hidden_states = patch_embed_(hidden_states);
-// at_npu::native::custom_ops::npu_rms_norm()
+LOG(INFO) << " Glm4VisionTransformerImpl patch_embed_ beging ";
 hidden_states = post_conv_layernorm_(hidden_states);
-// hidden_states = at_npu::native::custom_ops::npu_rms_norm(hidden_states);
+LOG(INFO) << " Glm4VisionTransformerImpl post_conv_layernorm_ beging ";
 
 auto [rotary_pos_emb, image_type_ids] = rot_pos_emb(grid_thw);
 auto emb = torch::cat({rotary_pos_emb, rotary_pos_emb}, -1);
@@ -665,9 +665,7 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
 auto repeats = grid_t.squeeze(1);
 auto repeated = torch::repeat_interleave(h_times_w, repeats, 0);
 c10::optional<torch::ScalarType> cumsum_dtype;
-// if (torch::jit::is_tracing()) {
-// cumsum_dtype = grid_thw.scalar_type();
-// } else {
+
 cumsum_dtype = torch::kInt32;
 auto cu_seqlens = torch::cumsum(repeated, 0, cumsum_dtype);
 namespace F = torch::nn::functional;
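For context, here is a minimal standalone LibTorch sketch of the cu_seqlens computation that the second hunk keeps after dropping the commented-out tracing branch. This is not the repository's code: the example grid_thw values, the main() wrapper, and the use of select() in place of the original squeeze(1) slicing are assumptions made only for illustration.

// Minimal sketch (assumption: illustrative only, not xllm code).
// grid_thw rows are (t, h, w) patch counts per image, as in the hunk above.
#include <torch/torch.h>
#include <iostream>

int main() {
  // Hypothetical example: one 2x2 image grid and one 4x4 image grid.
  auto grid_thw = torch::tensor({{1, 2, 2}, {1, 4, 4}}, torch::kLong);

  auto grid_t = grid_thw.select(/*dim=*/1, /*index=*/0);           // temporal patches per image
  auto h_times_w = grid_thw.select(1, 1) * grid_thw.select(1, 2);  // spatial patches per frame

  // Repeat each image's h*w by its temporal length, then take the cumulative
  // sum in int32 -- mirroring the hunk, where cumsum_dtype is now always kInt32.
  auto repeated = torch::repeat_interleave(h_times_w, grid_t, /*dim=*/0);
  auto cu_seqlens = torch::cumsum(repeated, 0, torch::kInt32);

  std::cout << cu_seqlens << std::endl;  // expected: 4, 20
  return 0;
}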

xllm/models/vlm/glm4v_moe.h

Lines changed: 17 additions & 11 deletions

@@ -42,6 +42,7 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module {
 Glm4vMoeForConditionalGenerationImpl(const ModelContext& context)
 : model_args_(context.get_model_args()),
 options_(context.get_tensor_options()) {
+std::cout << "----------------Glm4vMoeForConditionalGenerationImpl init begin ----------------- " << std::endl;
 visual_ = register_module("visual", Glm4VisionTransformer(context));
 
 language_model_ =
@@ -53,25 +54,28 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module {
 const std::optional<Glm4VImageInputs>& image_input,
 const std::optional<Glm4VVideoInputs>& video_input,
 const ModelInputParams& input_params) {
+// visual
+LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings visual_ begin ";
+torch::Tensor pixel = image_input->pixel_values.to(options_);
+LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings pixel aft ";
+auto image_embeds =
+visual_(pixel,
+image_input->image_grid_thw,
+input_params);
+LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings visual_ end ";
 auto inputs_embeds = language_model_->get_input_embeddings(input_ids);
-if (image_input) {
-// visual
-auto image_embeds =
-visual_(image_input->pixel_values.to(options_),
-image_input->image_grid_thw,
-input_params);
-// merge
-auto is_multimodal = torch::isin(input_ids,
-model_args_.image_token_id()); input_params.visual_pos_masks =
-is_multimodal; inputs_embeds.index_put_({is_multimodal}, image_embeds);
-}
+// merge
+auto is_multimodal = torch::isin(input_ids,
+model_args_.image_token_id()); input_params.visual_pos_masks =
+is_multimodal; inputs_embeds.index_put_({is_multimodal}, image_embeds);
 return inputs_embeds;
 }
 
 torch::Tensor forward(const torch::Tensor& tokens,
 const torch::Tensor& positions,
 std::vector<KVCache>& kv_caches,
 const ModelInputParams& input_params) {
+std::cout << "----------------Glm4vMoeForConditionalGenerationImpl beging ----------------- " << std::endl;
 LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl beging ";
 torch::NoGradGuard no_grad;
 const auto& mm_data = input_params.mm_data;
@@ -87,6 +91,8 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module {
 LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward Glm4VImageInputs beging ";
 if (pixel_values.defined() && image_grid_thw.defined())
 image_inputs = Glm4VImageInputs{pixel_values, image_grid_thw};
+else
+LOG(FATAL) << "Pixel value or image grid thw is null";
 
 LOG(INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings beging ";
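For context, here is a minimal standalone sketch of the merge step that this change hoists out of the removed if (image_input) guard: build a boolean mask of image-token positions with torch::isin and scatter the visual embeddings into the token embeddings with index_put_. Everything in it (the token id 151655, the tensor shapes, the main() wrapper) is a made-up assumption for illustration; it is not the repository's code.

// Minimal sketch (assumption: illustrative only, not xllm code) of the
// is_multimodal / index_put_ merge shown in the hunk above.
#include <torch/torch.h>
#include <iostream>

int main() {
  const int64_t image_token_id = 151655;  // hypothetical placeholder id
  auto input_ids = torch::tensor({101, 151655, 151655, 102}, torch::kLong);

  auto inputs_embeds = torch::zeros({4, 8});  // [seq_len, hidden] text embeddings
  auto image_embeds = torch::ones({2, 8});    // one row per image-token position

  // Mask of positions holding the image token, mirroring
  // torch::isin(input_ids, model_args_.image_token_id()).
  auto is_multimodal = torch::isin(input_ids, image_token_id);

  // In-place scatter of visual features into the embedding sequence,
  // mirroring inputs_embeds.index_put_({is_multimodal}, image_embeds).
  inputs_embeds.index_put_({is_multimodal}, image_embeds);

  std::cout << inputs_embeds << std::endl;  // rows 1 and 2 are now all ones
  return 0;
}

Note the behavioural change recorded in the hunk itself: the visual_ call and this merge now run unconditionally, whereas before the commit they were skipped whenever image_input was empty.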
