@@ -42,6 +42,7 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module {
4242 Glm4vMoeForConditionalGenerationImpl (const ModelContext& context)
4343 : model_args_(context.get_model_args()),
4444 options_ (context.get_tensor_options()) {
45+ std::cout << " ----------------Glm4vMoeForConditionalGenerationImpl init begin ----------------- " << std::endl;
4546 visual_ = register_module (" visual" , Glm4VisionTransformer (context));
4647
4748 language_model_ =
@@ -53,25 +54,28 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module {
5354 const std::optional<Glm4VImageInputs>& image_input,
5455 const std::optional<Glm4VVideoInputs>& video_input,
5556 const ModelInputParams& input_params) {
57+ // visual
58+ LOG (INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings visual_ begin " ;
59+ torch::Tensor pixel = image_input->pixel_values .to (options_);
60+ LOG (INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings pixel aft " ;
61+ auto image_embeds =
62+ visual_ (pixel,
63+ image_input->image_grid_thw ,
64+ input_params);
65+ LOG (INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings visual_ end " ;
5666 auto inputs_embeds = language_model_->get_input_embeddings (input_ids);
57- if (image_input) {
58- // visual
59- auto image_embeds =
60- visual_ (image_input->pixel_values .to (options_),
61- image_input->image_grid_thw ,
62- input_params);
63- // merge
64- auto is_multimodal = torch::isin (input_ids,
65- model_args_.image_token_id ()); input_params.visual_pos_masks =
66- is_multimodal; inputs_embeds.index_put_ ({is_multimodal}, image_embeds);
67- }
67+ // merge
68+ auto is_multimodal = torch::isin (input_ids,
69+ model_args_.image_token_id ()); input_params.visual_pos_masks =
70+ is_multimodal; inputs_embeds.index_put_ ({is_multimodal}, image_embeds);
6871 return inputs_embeds;
6972 }
7073
7174 torch::Tensor forward (const torch::Tensor& tokens,
7275 const torch::Tensor& positions,
7376 std::vector<KVCache>& kv_caches,
7477 const ModelInputParams& input_params) {
78+ std::cout << " ----------------Glm4vMoeForConditionalGenerationImpl beging ----------------- " << std::endl;
7579 LOG (INFO) << " Glm4vMoeForConditionalGenerationImpl beging " ;
7680 torch::NoGradGuard no_grad;
7781 const auto & mm_data = input_params.mm_data ;
@@ -87,6 +91,8 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module {
8791 LOG (INFO) << " Glm4vMoeForConditionalGenerationImpl forward Glm4VImageInputs beging " ;
8892 if (pixel_values.defined () && image_grid_thw.defined ())
8993 image_inputs = Glm4VImageInputs{pixel_values, image_grid_thw};
94+ else
95+ LOG (FATAL) << " Pixel value or image grid thw is null" ;
9096
9197 LOG (INFO) << " Glm4vMoeForConditionalGenerationImpl forward get_input_embeddings beging " ;
9298
0 commit comments