5 changes: 4 additions & 1 deletion video_chat/models/videochat_it.py
@@ -265,7 +265,10 @@ def forward(self, image, text_input):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.llama_tokenizer.bos_token_id
+        bos_embeds = self.llama_model.model.embed_tokens(
+            torch.tensor([[self.llama_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        )
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
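Note: this hunk fixes a real bug. `inputs_embeds` is a float tensor of token embeddings, so the old line broadcast the integer `bos_token_id` into every dimension of the first embedding slot instead of writing the BOS token's learned embedding vector. A minimal self-contained sketch of the difference (the toy vocabulary and sizes are illustrative, not the model's real ones):

    import torch
    import torch.nn as nn

    bos_token_id = 1                      # e.g. the LLaMA BOS id
    embed_tokens = nn.Embedding(32, 8)    # toy vocab of 32, hidden size 8
    inputs_embeds = torch.zeros(2, 5, 8)  # [batch, txt_len, hidden]

    # Buggy: broadcasts the raw integer id across all 8 embedding dims.
    inputs_embeds[:, :1] = bos_token_id
    print(inputs_embeds[0, 0])            # all 1.0 -- not an embedding at all

    # Fixed: look the id up in the embedding table, then write the vector.
    bos_embeds = embed_tokens(torch.tensor([[bos_token_id]]))  # shape [1, 1, 8]
    inputs_embeds[:, :1] = bos_embeds                          # broadcasts over batch
    print(torch.allclose(inputs_embeds[0, 0], embed_tokens.weight[bos_token_id]))  # True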
@@ -368,7 +368,12 @@ def forward(self, image, text_input, instruction):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.mistral_tokenizer.bos_token_id
+        bos_token_id = torch.tensor([[self.mistral_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        if self.use_lora:
+            bos_embeds = self.mistral_model.base_model.model.model.embed_tokens(bos_token_id)
+        else:
+            bos_embeds = self.mistral_model.model.embed_tokens(bos_token_id)
+        inputs_embeds[:, :1] = bos_embeds
 
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
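The `use_lora` branch is needed because wrapping the causal LM with PEFT/LoRA adds one extra level of attribute nesting before the decoder's `embed_tokens` is reachable. A quick introspection sketch of the chain the branch assumes (written against recent `peft`/`transformers` APIs; the tiny test checkpoint is an assumption used only so the sketch runs quickly):

    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    # any Mistral-style causal LM works; this tiny checkpoint name is an assumption
    base = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
    lora = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM", target_modules=["q_proj", "v_proj"]))

    print(type(lora.base_model).__name__)              # LoraModel (the PEFT wrapper)
    print(type(lora.base_model.model).__name__)        # MistralForCausalLM
    print(type(lora.base_model.model.model).__name__)  # MistralModel (decoder stack)

    # hence the extra hop when use_lora is True:
    embed = lora.base_model.model.model.embed_tokens
    assert embed is lora.get_input_embeddings()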
7 changes: 6 additions & 1 deletion video_chat2/models/videochat_mistra/videochat2_it_mistral.py
@@ -299,7 +299,12 @@ def forward(self, image, text_input, instruction):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.mistral_tokenizer.bos_token_id
+        bos_token_id = torch.tensor([[self.mistral_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        if self.use_lora:
+            bos_embeds = self.mistral_model.base_model.model.model.embed_tokens(bos_token_id)
+        else:
+            bos_embeds = self.mistral_model.model.embed_tokens(bos_token_id)
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
7 changes: 6 additions & 1 deletion video_chat2/models/videochat_vicuna/videochat2_it_vicuna.py
@@ -296,7 +296,12 @@ def forward(self, image, text_input, instruction):
         attention_mask = torch.zeros([batch_size, txt_len], dtype=torch.long).to(img_embeds.device)
         targets = torch.ones([batch_size, txt_len], dtype=torch.long).to(img_embeds.device).fill_(-100)
         # set bos_token
-        inputs_embeds[:, :1] = self.llama_tokenizer.bos_token_id
+        bos_token_id = torch.tensor([[self.llama_tokenizer.bos_token_id]], device=inputs_embeds.device)
+        if self.use_lora:
+            bos_embeds = self.llama_model.base_model.model.model.embed_tokens(bos_token_id)
+        else:
+            bos_embeds = self.llama_model.model.embed_tokens(bos_token_id)
+        inputs_embeds[:, :1] = bos_embeds
         for idx in range(batch_size):
             input_len = min(input_embed_list[idx].shape[1], txt_len - 1)
             # if less than txt_len, the input will be padding
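An alternative that would avoid repeating the `use_lora` branch across these files: `PreTrainedModel.get_input_embeddings()` resolves the token-embedding module through PEFT wrappers as well, so a single helper could serve every call site. A sketch under that assumption (`lookup_bos_embedding` is a hypothetical helper, not part of this PR):

    import torch

    def lookup_bos_embedding(model, bos_token_id, device):
        # get_input_embeddings() finds the token-embedding module whether or
        # not the model is wrapped by PEFT/LoRA, so no use_lora branch is needed.
        embed_tokens = model.get_input_embeddings()
        return embed_tokens(torch.tensor([[bos_token_id]], device=device))

Call sites would then shrink to, e.g., `inputs_embeds[:, :1] = lookup_bos_embedding(self.mistral_model, self.mistral_tokenizer.bos_token_id, inputs_embeds.device)`.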