@@ -605,7 +605,6 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
605605 blocks_->push_back (block);
606606 layers_.push_back (block);
607607 }
608- // TODO 融合算子
609608 post_layernorm_ = register_module (" post_layernorm" , Glm4VisionRmsNorm (context));
610609
611610 downsample_ = register_module (" downsample" , torch::nn::Conv2d (torch::nn::Conv2dOptions (hidden_size_, out_hidden_size_, spatial_merge_size_)
@@ -672,8 +671,6 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
672671 auto repeated = torch::repeat_interleave (h_times_w, repeats, 0 );
673672 c10::optional<torch::ScalarType> cumsum_dtype;
674673
675- LOG (INFO) << " Glm4VisionTransformerImpl repeated " << repeated;
676-
677674 cumsum_dtype = torch::kInt32 ;
678675 auto cu_seqlens = torch::cumsum (repeated, 0 , cumsum_dtype);
679676 namespace F = torch::nn::functional;
@@ -682,27 +679,21 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
682679 std::vector<int > seqlens;
683680 seqlens.assign (cu_seqlens.data_ptr <int >(),cu_seqlens.data_ptr <int >() + cu_seqlens.numel ());
684681
685- LOG (INFO) << " Glm4VisionTransformerImpl forward embedding before cu_seqlens " << cu_seqlens << " seqlens.size()" << seqlens.size ();
686682 hidden_states = embeddings_ (hidden_states, seqlens, grid_thw, image_type_ids.select (1 , 0 ), image_type_ids.select (1 , 1 ));
687- LOG (INFO) << " Glm4VisionTransformerImpl forward embedding after " ;
688683 ModelInputParams& input_params_new = const_cast <ModelInputParams&>(input_params);
689684 torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu ();
690685 std::vector<int > cu_seqlens_vec (
691- cu_seqlens_cpu.data_ptr <int >(), // full seqlen vec
686+ cu_seqlens_cpu.data_ptr <int >(),
692687 cu_seqlens_cpu.data_ptr <int >() + cu_seqlens_cpu.numel ());
688+ cu_seqlens = cu_seqlens.to (hidden_states.device ());
693689 for (int idx = 0 ; idx < blocks_->size (); ++idx) {
694- hidden_states = layers_[idx](hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, input_params_new, idx); // TODO
695- LOG (INFO) << " Glm4VisionTransformerImpl forward layer " << idx;
690+ hidden_states = layers_[idx](hidden_states, m_cos, m_sin, cu_seqlens, cu_seqlens_vec, input_params_new, idx);
696691 }
697- LOG (INFO) << " Glm4VisionTransformerImpl forward layer after " ;
698692 hidden_states = post_layernorm_ (hidden_states);
699693 hidden_states = hidden_states.view ({-1 , spatial_merge_size_, spatial_merge_size_, hidden_states.size (-1 )});
700- // TO down sample merge op
701694 hidden_states = hidden_states.permute ({0 , 3 , 1 , 2 });
702695 hidden_states = downsample_ (hidden_states).view ({-1 , out_hidden_size_});
703- LOG (INFO) << " Glm4VisionTransformerImpl downsample after" ;
704696 hidden_states = merger_ (hidden_states);
705- LOG (INFO) << " Glm4VisionTransformerImpl forward end" ;
706697 return hidden_states;
707698 };
708699
@@ -820,12 +811,10 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
820811 const ModelInputParams& input_params) {
821812 auto inputs_embeds = language_model_->get_input_embeddings (input_ids);
822813 if (image_input) {
823- // visual
824814 auto image_embeds =
825815 visual_ (image_input->pixel_values .to (options_),
826816 image_input->image_grid_thw ,
827817 input_params);
828- // merge
829818 auto is_multimodal = torch::isin (input_ids,
830819 model_args_.image_token_id ()); input_params.visual_pos_masks =
831820 is_multimodal; inputs_embeds.index_put_ ({is_multimodal}, image_embeds);
@@ -851,7 +840,6 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
851840
852841 if (pixel_values.defined () && image_grid_thw.defined ())
853842 image_inputs = Glm4VImageInputs{pixel_values, image_grid_thw};
854-
855843 auto inputs_embeds = get_input_embeddings (tokens, image_inputs, video_inputs, input_params);
856844 input_params.input_embedding = inputs_embeds;
857845 auto emb = language_model_ (tokens, positions, kv_caches, input_params);
@@ -869,7 +857,6 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
869857 visual_->load_state_dict (
870858 state_dict->get_dict_with_prefix (" model.visual." ));
871859 }
872- // verify
873860 visual_->verify_loaded_weights (" model.visual." );
874861 visual_->merge_loaded_weights ();
875862 if (!model_args_.image_embedding_mode ()) {
0 commit comments