@@ -363,7 +363,7 @@ class Glm4vVisionEmbeddingsImpl : public torch::nn::Module {
363363 }
364364 torch::Tensor forward (
365365 torch::Tensor x,
366- std::vector<int64_t > lengths,
366+ std::vector<int > lengths,
367367 torch::Tensor image_shapes,
368368 torch::Tensor h_coords,
369369 torch::Tensor w_coords
@@ -399,12 +399,12 @@ class Glm4vVisionEmbeddingsImpl : public torch::nn::Module {
399399 std::vector<torch::Tensor> target_w_list;
400400 target_h_list.reserve (batch_size);
401401 target_w_list.reserve (batch_size);
402-
402+ LOG (INFO) << " Glm4vVisionEmbeddingsImpl forward batch_size: " << batch_size << " image_shapes " << image_shapes;
403403 for (int64_t i = 0 ; i < batch_size; ++i) {
404404 const int64_t seq_len = lengths[i];
405405 const auto img_h = image_shapes.index ({i, 1 }).to (torch::kFloat32 );
406406 const auto img_w = image_shapes.index ({i, 2 }).to (torch::kFloat32 );
407-
407+ LOG (INFO) << " Glm4vVisionEmbeddingsImpl forward batch_size idx " << i;
408408 target_h_list.push_back (img_h.repeat ({seq_len}));
409409 target_w_list.push_back (img_w.repeat ({seq_len}));
410410 }
@@ -417,16 +417,17 @@ class Glm4vVisionEmbeddingsImpl : public torch::nn::Module {
417417
418418 const auto norm_w = ((w_coords_fp32 + 0 .5f ) / target_w) * 2 .0f - 1 .0f ;
419419 const auto norm_h = ((h_coords_fp32 + 0 .5f ) / target_h) * 2 .0f - 1 .0f ;
420-
420+ LOG (INFO) << " Glm4vVisionEmbeddingsImpl stack " ;
421421 auto grid = torch::stack ({norm_w, norm_h}, -1 )
422422 .unsqueeze (0 )
423423 .unsqueeze (2 );
424-
424+ LOG (INFO) << " Glm4vVisionEmbeddingsImpl stack after " ;
425425 namespace F = torch::nn::functional;
426426 auto interpolated_embed = F::grid_sample (
427427 pos_embed_2d,
428428 grid,
429429 F::GridSampleFuncOptions ().mode (torch::kBicubic ).padding_mode (torch::kBorder ).align_corners (false ));
430+ LOG (INFO) << " Glm4vVisionEmbeddingsImpl interpolated_embed" ;
430431 adapted_pos_embed = interpolated_embed
431432 .squeeze (0 )
432433 .squeeze (-1 )
@@ -656,34 +657,32 @@ class Glm4VisionTransformerImpl : public torch::nn::Module {
656657 auto emb = torch::cat ({rotary_pos_emb, rotary_pos_emb}, -1 );
657658 auto m_cos = emb.cos ();
658659 auto m_sin = emb.sin ();
660+ LOG (INFO) << " Glm4VisionTransformerImpl" << " numel=" << grid_thw.numel () << " min=" << grid_thw.min ().item <float >();
661+ LOG (INFO) << grid_thw;
659662
660663 auto device = grid_thw.device ();
661- auto grid_t = grid_thw.index_select (1 , torch::tensor ({0 }, torch::TensorOptions ().dtype (torch::kLong ).device (device)));
662- auto grid_h = grid_thw.index_select (1 , torch::tensor ({1 }, torch::TensorOptions ().dtype (torch::kLong ).device (device)));
663- auto grid_w = grid_thw.index_select (1 , torch::tensor ({2 }, torch::TensorOptions ().dtype (torch::kLong ).device (device)));
664+ auto grid_t = grid_thw.index_select (1 , torch::tensor ({0 }, torch::TensorOptions ().dtype (torch::kInt ).device (device)));
665+ auto grid_h = grid_thw.index_select (1 , torch::tensor ({1 }, torch::TensorOptions ().dtype (torch::kInt ).device (device)));
666+ auto grid_w = grid_thw.index_select (1 , torch::tensor ({2 }, torch::TensorOptions ().dtype (torch::kInt ).device (device)));
664667 auto h_times_w = (grid_h * grid_w).squeeze (1 );
665668 auto repeats = grid_t .squeeze (1 );
666669 auto repeated = torch::repeat_interleave (h_times_w, repeats, 0 );
667670 c10::optional<torch::ScalarType> cumsum_dtype;
668671
672+ LOG (INFO) << " Glm4VisionTransformerImpl repeated " << repeated;
673+
669674 cumsum_dtype = torch::kInt32 ;
670675 auto cu_seqlens = torch::cumsum (repeated, 0 , cumsum_dtype);
671676 namespace F = torch::nn::functional;
672- cu_seqlens = F::pad (
673- cu_seqlens, F::PadFuncOptions ({1 , 0 }).mode (torch::kConstant ).value (0 ));
674- cu_seqlens = torch::diff (cu_seqlens);
675- torch::Tensor cu_seqlens_slice1 = cu_seqlens.narrow (0 , 1 , cu_seqlens.size (0 ) - 1 );
676- torch::Tensor cu_seqlens_slice0 = cu_seqlens.narrow (0 , 0 , cu_seqlens.size (0 ) - 1 );
677- torch::Tensor seqlens_tensor = cu_seqlens_slice1 - cu_seqlens_slice0;
678- std::vector<int64_t > seqlens;
679- seqlens.assign (
680- seqlens_tensor.cpu ().to (torch::kLong ).data_ptr <int64_t >(),
681- seqlens_tensor.cpu ().to (torch::kLong ).data_ptr <int64_t >() + seqlens_tensor.numel ()
682- );
683- LOG (INFO) << " Glm4VisionTransformerImpl forward embedding before " ;
677+ cu_seqlens = F::pad (cu_seqlens, F::PadFuncOptions ({1 , 0 }).mode (torch::kConstant ).value (0 ));
678+ cu_seqlens = torch::diff (cu_seqlens).cpu ().to (torch::kInt );
679+ std::vector<int > seqlens;
680+ seqlens.assign (cu_seqlens.data_ptr <int >(),cu_seqlens.data_ptr <int >() + cu_seqlens.numel ());
681+
682+ LOG (INFO) << " Glm4VisionTransformerImpl forward embedding before cu_seqlens " << cu_seqlens << " seqlens.size()" << seqlens.size ();
684683 hidden_states = embeddings_ (hidden_states, seqlens, grid_thw, image_type_ids.select (1 , 0 ), image_type_ids.select (1 , 1 ));
685- ModelInputParams& input_params_new =
686- const_cast <ModelInputParams&>(input_params);
684+ LOG (INFO) << " Glm4VisionTransformerImpl forward embedding after " ;
685+ ModelInputParams& input_params_new = const_cast <ModelInputParams&>(input_params);
687686 torch::Tensor cu_seqlens_cpu = cu_seqlens.cpu ();
688687 std::vector<int > cu_seqlens_vec (
689688 cu_seqlens_cpu.data_ptr <int >(), // full seqlen vec
0 commit comments