Commit 6cda3ee

ext.xingsilan1 authored and DongheJin committed
feat: vit forward run success
1 parent 7ee36cc commit 6cda3ee

xllm/models/vlm/glm4v.h

Lines changed: 16 additions & 10 deletions
@@ -464,34 +464,36 @@ class Glm4_VisionPatchMergerImpl : public torch::nn::Module {
  public:
   Glm4_VisionPatchMergerImpl(const ModelContext& context) {
     auto model_args = context.get_model_args();
-    auto options = context.get_tensor_options();
+    options_ = context.get_tensor_options();
     auto parallel_args = context.get_parallel_args();
     int64_t dim = model_args.mm_projection_dim();
     int64_t context_dim = model_args.mm_intermediate_size();
     norm_ = register_module("norm", torch::nn::LayerNorm(torch::nn::LayerNormOptions({dim})));
-    norm_->weight.set_data(norm_->weight.to(options));
-    norm_->bias.set_data(norm_->bias.to(options));
+    norm_->weight.set_data(norm_->weight.to(options_));
+    norm_->bias.set_data(norm_->bias.to(options_));
     proj_ = register_module(
         "proj",
         torch::nn::Linear(torch::nn::LinearOptions(dim, dim).bias(false)));
-
+    proj_->weight.set_data(proj_->weight.to(options_));
     act_ = register_module("act", torch::nn::GELU());
     silu_ = register_module("silu", torch::nn::SiLU());

     gate_ = register_module(
         "gate",
         torch::nn::Linear(torch::nn::LinearOptions(dim, context_dim).bias(false)));
-
+    gate_->weight.set_data(gate_->weight.to(options_));
     up_ = register_module(
         "up",
         torch::nn::Linear(torch::nn::LinearOptions(dim, context_dim).bias(false)));
-
+    up_->weight.set_data(up_->weight.to(options_));
     down_ = register_module(
         "down",
         torch::nn::Linear(torch::nn::LinearOptions(context_dim, dim).bias(false)));
+    down_->weight.set_data(down_->weight.to(options_));
   }

   torch::Tensor forward(torch::Tensor x) {
+    LOG(INFO) << " Glm4_VisionPatchMergerImpl forward begin " << x.device() << " options_.device(): " << options_.device();
     x = proj_(x);
     x = act_(norm_(x));
     x = down_(torch::mul(silu_((gate_(x))), up_(x)));
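Note on the set_data() casts above (the same pattern is applied to the downsample convolution in a later hunk): register_module() materializes parameters with LibTorch's default tensor options (float32, CPU), so a module meant to run in the context's dtype/device has to re-point each parameter afterwards. A minimal standalone sketch of the pattern, assuming plain LibTorch only — the MergerSketch name and the 64-wide layer are invented, and kDouble stands in for the accelerator dtype so the sketch runs on CPU:

#include <torch/torch.h>

// Sketch of the cast-after-register pattern from the commit; not the xllm code.
struct MergerSketchImpl : torch::nn::Module {
  explicit MergerSketchImpl(torch::TensorOptions options) : options_(options) {
    proj_ = register_module(
        "proj", torch::nn::Linear(torch::nn::LinearOptions(64, 64).bias(false)));
    // register_module() created the weight in default options; re-materialize
    // it in the target dtype/device, as the commit does for every submodule.
    proj_->weight.set_data(proj_->weight.to(options_));
  }

  torch::Tensor forward(torch::Tensor x) { return proj_(x); }

  torch::nn::Linear proj_{nullptr};
  torch::TensorOptions options_;
};
TORCH_MODULE(MergerSketch);

int main() {
  auto options = torch::TensorOptions().dtype(torch::kDouble);
  MergerSketch merger(options);
  auto x = torch::randn({4, 64}, options);
  auto y = merger(x);  // forward succeeds: weight and input dtypes now agree
}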
@@ -568,6 +570,7 @@ class Glm4_VisionPatchMergerImpl : public torch::nn::Module {
   torch::nn::Linear down_{nullptr};
   torch::nn::GELU act_{nullptr};
   torch::nn::SiLU silu_{nullptr};
+  torch::TensorOptions options_;


   bool is_proj_weight_loaded = false;
@@ -607,6 +610,8 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {

     downsample_ = register_module("downsample", torch::nn::Conv2d(torch::nn::Conv2dOptions(hidden_size_, out_hidden_size_, spatial_merge_size_)
         .stride(spatial_merge_size_).bias(true).padding(0)));
+    downsample_->weight.set_data(downsample_->weight.to(options_));
+    downsample_->bias.set_data(downsample_->bias.to(options_));
     merger_ = register_module("merger", Glm4_VisionPatchMerger(context));

   }
@@ -655,10 +660,8 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {

     auto [rotary_pos_emb, image_type_ids] = rot_pos_emb(grid_thw);
     auto emb = torch::cat({rotary_pos_emb, rotary_pos_emb}, -1);
-    auto m_cos = emb.cos();
-    auto m_sin = emb.sin();
-    LOG(INFO) << " Glm4VisionTransformerImpl" << " numel=" << grid_thw.numel() << " min=" << grid_thw.min().item<float>();
-    LOG(INFO) << grid_thw;
+    auto m_cos = emb.cos().type_as(hidden_states);
+    auto m_sin = emb.sin().type_as(hidden_states);

     auto device = grid_thw.device();
     auto grid_t = grid_thw.index_select(1, torch::tensor({0}, torch::TensorOptions().dtype(torch::kInt).device(device)));
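Note on the type_as() change: emb.cos() and emb.sin() inherit the dtype of rotary_pos_emb (typically float32), while hidden_states may run in half precision, so without the cast the rotary elementwise math would promote the activations back to float32 (or fail on backends that require matching dtypes). A minimal sketch of the effect, assuming plain LibTorch — the shapes are invented and kBFloat16 stands in for the model's activation dtype so the sketch runs on CPU:

#include <torch/torch.h>

int main() {
  // Invented shapes; bf16 stands in for the activation dtype.
  auto hidden_states = torch::randn({16, 128}, torch::kBFloat16);
  auto rotary_pos_emb = torch::randn({16, 64});  // float32 by default

  auto emb = torch::cat({rotary_pos_emb, rotary_pos_emb}, -1);
  auto m_cos = emb.cos().type_as(hidden_states);  // float32 -> bf16
  auto m_sin = emb.sin().type_as(hidden_states);

  // An elementwise mix in the style of rotary attention now stays in the
  // activation dtype instead of being promoted to float32.
  auto mixed = hidden_states * m_cos + hidden_states * m_sin;
  TORCH_CHECK(mixed.dtype() == hidden_states.dtype());
}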
@@ -689,6 +692,7 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
         cu_seqlens_cpu.data_ptr<int>() + cu_seqlens_cpu.numel());
     for (int idx = 0; idx < blocks_->size(); ++idx) {
       hidden_states = layers_[idx](hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, input_params_new, idx); //TODO
+      LOG(INFO) << " Glm4VisionTransformerImpl forward layer " << idx;
     }
     LOG(INFO) << " Glm4VisionTransformerImpl forward layer after ";
     hidden_states = post_layernorm_(hidden_states);
@@ -782,6 +786,8 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
   bool is_post_layernorm_weight_loaded = false;
   bool is_downsample_weight_loaded_ = false;
   bool is_downsample_bias_loaded_ = false;
+  torch::Tensor m_cos;
+  torch::Tensor m_sin;
 };
 TORCH_MODULE(Glm4VisionTransformer);