From 8293adc1f08a54e0bc7c8c26892e4a9275e0fc17 Mon Sep 17 00:00:00 2001
From: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com>
Date: Wed, 11 Feb 2026 19:34:39 +0530
Subject: [PATCH] Fix BOS token embedding: use proper embedding lookup instead
 of raw integer

In the forward() methods, the BOS token position in inputs_embeds was
being set to the raw integer bos_token_id (e.g. 1) instead of the actual
embedding vector. Since inputs_embeds is a float tensor produced by
embed_tokens(), assigning a scalar integer broadcasts that value across
the entire embedding dimension, corrupting the BOS representation.

Fix by passing bos_token_id through embed_tokens() to obtain the correct
embedding vector before assignment. This is applied consistently across
all four model files, respecting each file's model attribute path and
LoRA configuration.

Affected files:
- video_chat/models/videochat_it.py
- video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py
- video_chat2/models/videochat_mistra/videochat2_it_mistral.py
- video_chat2/models/videochat_mistra/videochat2_it_hd_mistral.py
---
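Notes (reviewer-facing, discarded by git am): a minimal standalone
sketch of the failure mode and the fix. The vocabulary size and
embedding dimension below are toy values for illustration, not the
models' real sizes:

    import torch

    embed = torch.nn.Embedding(32000, 8)        # stand-in for embed_tokens
    ids = torch.randint(0, 32000, (2, 5))       # (batch, seq) token ids
    inputs_embeds = embed(ids)                  # (batch, seq, dim) float tensor

    bos_token_id = 1
    inputs_embeds[:, :1] = bos_token_id         # BUG: fills every dim with 1.0

    bos = embed(torch.tensor([[bos_token_id]])) # (1, 1, dim) embedding vector
    inputs_embeds[:, :1] = bos                  # FIX: broadcasts over the batch

In the LoRA branches, embed_tokens is reached via
base_model.model.model.embed_tokens because PEFT wraps the causal-LM
model in a PeftModel/LoraModel pair; the non-LoRA branches use the plain
model.embed_tokens path.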
 video_chat/models/videochat_it.py                         | 5 ++++-
 .../models/videochat_mistra/videochat2_it_hd_mistral.py   | 7 ++++++-
 .../models/videochat_mistra/videochat2_it_mistral.py      | 7 ++++++-
 .../models/videochat_vicuna/videochat2_it_vicuna.py       | 7 ++++++-
 4 files changed, 22 insertions(+), 4 deletions(-)

diff --git a/video_chat/models/videochat_it.py b/video_chat/models/videochat_it.py
index b0beb9b..efb059b 100644
--- a/video_chat/models/videochat_it.py
+++ b/video_chat/models/videochat_it.py
@@ -265,7 +265,10 @@ def forward(self, image, text_input):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.llama_tokenizer.bos_token_id
+        bos_embeds = self.llama_model.model.embed_tokens(
+            torch.tensor([[self.llama_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        )
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
diff --git a/video_chat2/models/videochat_mistra/videochat2_it_hd_mistral.py b/video_chat2/models/videochat_mistra/videochat2_it_hd_mistral.py
index c66e0e3..a447710 100644
--- a/video_chat2/models/videochat_mistra/videochat2_it_hd_mistral.py
+++ b/video_chat2/models/videochat_mistra/videochat2_it_hd_mistral.py
@@ -368,7 +368,12 @@ def forward(self, image, text_input, instruction):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.mistral_tokenizer.bos_token_id
+        bos_token_id = torch.tensor([[self.mistral_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        if self.use_lora:
+            bos_embeds = self.mistral_model.base_model.model.model.embed_tokens(bos_token_id)
+        else:
+            bos_embeds = self.mistral_model.model.embed_tokens(bos_token_id)
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
diff --git a/video_chat2/models/videochat_mistra/videochat2_it_mistral.py b/video_chat2/models/videochat_mistra/videochat2_it_mistral.py
index df325de..9806e13 100644
--- a/video_chat2/models/videochat_mistra/videochat2_it_mistral.py
+++ b/video_chat2/models/videochat_mistra/videochat2_it_mistral.py
@@ -299,7 +299,12 @@ def forward(self, image, text_input, instruction):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.mistral_tokenizer.bos_token_id
+        bos_token_id = torch.tensor([[self.mistral_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        if self.use_lora:
+            bos_embeds = self.mistral_model.base_model.model.model.embed_tokens(bos_token_id)
+        else:
+            bos_embeds = self.mistral_model.model.embed_tokens(bos_token_id)
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
diff --git a/video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py b/video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py
index bb6d7e0..0fc1f8f 100644
--- a/video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py
+++ b/video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py
@@ -296,7 +296,12 @@ def forward(self, image, text_input, instruction):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.llama_tokenizer.bos_token_id
+        bos_token_id = torch.tensor([[self.llama_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        if self.use_lora:
+            bos_embeds = self.llama_model.base_model.model.model.embed_tokens(bos_token_id)
+        else:
+            bos_embeds = self.llama_model.model.embed_tokens(bos_token_id)
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding