diff --git a/video_chat/models/videochat_it.py b/video_chat/models/videochat_it.py
index b0beb9b..efb059b 100644
--- a/video_chat/models/videochat_it.py
+++ b/video_chat/models/videochat_it.py
@@ -265,7 +265,10 @@ def forward(self, image, text_input):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.llama_tokenizer.bos_token_id
+        bos_embeds = self.llama_model.model.embed_tokens(
+            torch.tensor([[self.llama_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        )
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
diff --git a/video_chat2/models/videochat_mistra/videochat2_it_hd_mistral.py b/video_chat2/models/videochat_mistra/videochat2_it_hd_mistral.py
index c66e0e3..a447710 100644
--- a/video_chat2/models/videochat_mistra/videochat2_it_hd_mistral.py
+++ b/video_chat2/models/videochat_mistra/videochat2_it_hd_mistral.py
@@ -368,7 +368,12 @@ def forward(self, image, text_input, instruction):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.mistral_tokenizer.bos_token_id
+        bos_token_id = torch.tensor([[self.mistral_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        if self.use_lora:
+            bos_embeds = self.mistral_model.base_model.model.model.embed_tokens(bos_token_id)
+        else:
+            bos_embeds = self.mistral_model.model.embed_tokens(bos_token_id)
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
diff --git a/video_chat2/models/videochat_mistra/videochat2_it_mistral.py b/video_chat2/models/videochat_mistra/videochat2_it_mistral.py
index df325de..9806e13 100644
--- a/video_chat2/models/videochat_mistra/videochat2_it_mistral.py
+++ b/video_chat2/models/videochat_mistra/videochat2_it_mistral.py
@@ -299,7 +299,12 @@ def forward(self, image, text_input, instruction):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.mistral_tokenizer.bos_token_id
+        bos_token_id = torch.tensor([[self.mistral_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        if self.use_lora:
+            bos_embeds = self.mistral_model.base_model.model.model.embed_tokens(bos_token_id)
+        else:
+            bos_embeds = self.mistral_model.model.embed_tokens(bos_token_id)
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
diff --git a/video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py b/video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py
index bb6d7e0..0fc1f8f 100644
--- a/video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py
+++ b/video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py
@@ -296,7 +296,12 @@ def forward(self, image, text_input, instruction):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.llama_tokenizer.bos_token_id
+        bos_token_id = torch.tensor([[self.llama_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        if self.use_lora:
+            bos_embeds = self.llama_model.base_model.model.model.embed_tokens(bos_token_id)
+        else:
+            bos_embeds = self.llama_model.model.embed_tokens(bos_token_id)
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
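All four hunks fix the same bug: the old code assigned the integer bos_token_id directly into inputs_embeds[:, :1], which broadcasts the raw token id across the hidden dimension instead of writing the BOS token's embedding vector there. The following standalone sketch reproduces both behaviors on a toy nn.Embedding; the sizes and the variables other than those named in the patch are hypothetical stand-ins, not the repository's actual code.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the pieces the patch touches: an embedding
# table playing the role of llama_model.model.embed_tokens, and a
# pre-allocated inputs_embeds buffer of shape (batch, txt_len, hidden).
hidden_size, vocab_size, bos_token_id = 16, 100, 1
embed_tokens = nn.Embedding(vocab_size, hidden_size)
batch_size, txt_len = 2, 8
inputs_embeds = torch.zeros(batch_size, txt_len, hidden_size)

# Buggy version (the `-` lines): the scalar id broadcasts into every
# element of the first position's vector, producing a row of 1.0s that
# is not any token's embedding.
inputs_embeds[:, :1] = bos_token_id
assert torch.all(inputs_embeds[:, 0] == float(bos_token_id))

# Fixed version (the `+` lines): embed the BOS id, yielding a
# (1, 1, hidden_size) tensor that broadcasts over the batch, so the
# first slot of every sample holds the actual BOS embedding.
bos_embeds = embed_tokens(torch.tensor([[bos_token_id]], device=inputs_embeds.device))
inputs_embeds[:, :1] = bos_embeds
assert torch.allclose(
    inputs_embeds[:, 0],
    embed_tokens.weight[bos_token_id].expand(batch_size, -1),
)

In the use_lora branches the extra base_model.model prefix reflects the PEFT wrapper around the decoder, which is why the embedding table sits one level deeper than in the plain model; calling get_input_embeddings() on the wrapped model would likely resolve either layout, though the patch keeps the explicit attribute paths.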