@@ -464,34 +464,36 @@ class Glm4_VisionPatchMergerImpl : public torch::nn::Module {
  public:
   Glm4_VisionPatchMergerImpl(const ModelContext& context) {
     auto model_args = context.get_model_args();
-    auto options = context.get_tensor_options();
+    options_ = context.get_tensor_options();
     auto parallel_args = context.get_parallel_args();
     int64_t dim = model_args.mm_projection_dim();
     int64_t context_dim = model_args.mm_intermediate_size();
     norm_ = register_module("norm", torch::nn::LayerNorm(torch::nn::LayerNormOptions({dim})));
-    norm_->weight.set_data(norm_->weight.to(options));
-    norm_->bias.set_data(norm_->bias.to(options));
+    norm_->weight.set_data(norm_->weight.to(options_));
+    norm_->bias.set_data(norm_->bias.to(options_));
     proj_ = register_module(
         "proj",
         torch::nn::Linear(torch::nn::LinearOptions(dim, dim).bias(false)));
-
+    proj_->weight.set_data(proj_->weight.to(options_));
     act_ = register_module("act", torch::nn::GELU());
     silu_ = register_module("silu", torch::nn::SiLU());
 
     gate_ = register_module(
         "gate",
         torch::nn::Linear(torch::nn::LinearOptions(dim, context_dim).bias(false)));
-
+    gate_->weight.set_data(gate_->weight.to(options_));
     up_ = register_module(
         "up",
         torch::nn::Linear(torch::nn::LinearOptions(dim, context_dim).bias(false)));
-
+    up_->weight.set_data(up_->weight.to(options_));
     down_ = register_module(
         "down",
         torch::nn::Linear(torch::nn::LinearOptions(context_dim, dim).bias(false)));
+    down_->weight.set_data(down_->weight.to(options_));
   }
 
   torch::Tensor forward(torch::Tensor x) {
+    LOG(INFO) << "Glm4_VisionPatchMergerImpl forward begin " << x.device() << " options_.device() : " << options_.device();
     x = proj_(x);
     x = act_(norm_(x));
     x = down_(torch::mul(silu_(gate_(x)), up_(x)));
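
The constructor changes above all follow one pattern: parameters created by libtorch module constructors default to float32 on the CPU, so each registered weight is reassigned with set_data(weight.to(options_)) to match the model's tensor options. Below is a minimal, self-contained sketch of that pattern; the PatchProjection module name and the example options are illustrative, not taken from this file.

#include <torch/torch.h>

// Sketch: cast a module's parameters to a target dtype/device right after
// registration, using set_data() so the parameter object itself is updated.
class PatchProjectionImpl : public torch::nn::Module {
 public:
  PatchProjectionImpl(int64_t dim, torch::TensorOptions options) : options_(options) {
    proj_ = register_module(
        "proj", torch::nn::Linear(torch::nn::LinearOptions(dim, dim).bias(false)));
    // Freshly created weights are float32 on CPU; move them to the target options.
    proj_->weight.set_data(proj_->weight.to(options_));
  }

  torch::Tensor forward(torch::Tensor x) { return proj_(x); }

 private:
  torch::nn::Linear proj_{nullptr};
  torch::TensorOptions options_;
};
TORCH_MODULE(PatchProjection);

// Example (hypothetical options): PatchProjection proj(1024, torch::TensorOptions().dtype(torch::kHalf));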
@@ -568,6 +570,7 @@ class Glm4_VisionPatchMergerImpl : public torch::nn::Module {
   torch::nn::Linear down_{nullptr};
   torch::nn::GELU act_{nullptr};
   torch::nn::SiLU silu_{nullptr};
+  torch::TensorOptions options_;
 
 
   bool is_proj_weight_loaded = false;
@@ -607,6 +610,8 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
 
     downsample_ = register_module("downsample", torch::nn::Conv2d(torch::nn::Conv2dOptions(hidden_size_, out_hidden_size_, spatial_merge_size_)
         .stride(spatial_merge_size_).bias(true).padding(0)));
+    downsample_->weight.set_data(downsample_->weight.to(options_));
+    downsample_->bias.set_data(downsample_->bias.to(options_));
     merger_ = register_module("merger", Glm4_VisionPatchMerger(context));
 
   }
@@ -655,10 +660,8 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
 
     auto [rotary_pos_emb, image_type_ids] = rot_pos_emb(grid_thw);
     auto emb = torch::cat({rotary_pos_emb, rotary_pos_emb}, -1);
-    auto m_cos = emb.cos();
-    auto m_sin = emb.sin();
-    LOG(INFO) << "Glm4VisionTransformerImpl" << " numel=" << grid_thw.numel() << " min=" << grid_thw.min().item<float>();
-    LOG(INFO) << grid_thw;
+    auto m_cos = emb.cos().type_as(hidden_states);
+    auto m_sin = emb.sin().type_as(hidden_states);
 
     auto device = grid_thw.device();
     auto grid_t = grid_thw.index_select(1, torch::tensor({0}, torch::TensorOptions().dtype(torch::kInt).device(device)));
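
The dtype fix in this hunk replaces the earlier debug logging: the cos/sin tables produced from the rotary embedding come out of cos()/sin() as float32, while hidden_states is typically half or bfloat16, so both are cast once with type_as(hidden_states) before being handed to every block. A minimal sketch of the idea follows; rotate_half and apply_rotary are illustrative helpers using the usual rotary-embedding convention, not code from this file.

#include <torch/torch.h>

// Rotate the two halves of the last dimension: (a, b) -> (-b, a).
static torch::Tensor rotate_half(const torch::Tensor& x) {
  auto chunks = x.chunk(2, /*dim=*/-1);
  return torch::cat({-chunks[1], chunks[0]}, /*dim=*/-1);
}

// Apply rotary position embedding, casting the cos/sin tables to the
// activations' dtype/device once so mixed-precision inputs stay consistent.
static torch::Tensor apply_rotary(const torch::Tensor& x, const torch::Tensor& emb) {
  auto m_cos = emb.cos().type_as(x);
  auto m_sin = emb.sin().type_as(x);
  return x * m_cos + rotate_half(x) * m_sin;
}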
@@ -689,6 +692,7 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
                              cu_seqlens_cpu.data_ptr<int>() + cu_seqlens_cpu.numel());
     for (int idx = 0; idx < blocks_->size(); ++idx) {
       hidden_states = layers_[idx](hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, input_params_new, idx);  // TODO
+      LOG(INFO) << "Glm4VisionTransformerImpl forward layer " << idx;
     }
     LOG(INFO) << "Glm4VisionTransformerImpl forward layer after";
     hidden_states = post_layernorm_(hidden_states);
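
For context on the cu_seqlens / cu_seqlens_vec arguments threaded through each block above: they follow the common variable-length attention convention of a zero-prefixed running sum of per-sequence token counts, copied to the CPU so a plain vector can be passed alongside the tensor. A minimal sketch of that convention, with a hypothetical helper taking a plain-vector input (the real code derives the counts from grid_thw on the device):

#include <cstdint>
#include <vector>

// Build cumulative sequence lengths: sequence i occupies the packed rows
// [cu_seqlens[i], cu_seqlens[i+1]) of hidden_states.
std::vector<int32_t> build_cu_seqlens(const std::vector<int32_t>& seq_lens) {
  std::vector<int32_t> cu_seqlens = {0};
  cu_seqlens.reserve(seq_lens.size() + 1);
  for (int32_t len : seq_lens) {
    cu_seqlens.push_back(cu_seqlens.back() + len);
  }
  return cu_seqlens;
}

// Example: build_cu_seqlens({64, 256}) -> {0, 64, 320}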
@@ -782,6 +786,8 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
   bool is_post_layernorm_weight_loaded = false;
   bool is_downsample_weight_loaded_ = false;
   bool is_downsample_bias_loaded_ = false;
+  torch::Tensor m_cos;
+  torch::Tensor m_sin;
 };
 TORCH_MODULE(Glm4VisionTransformer);
 