
Tool Outputs in Attention Mask for Log Probabilities #7

@michalsr


From my understanding, the attention mask passed into `_get_per_token_logps` at https://github.com/Time-Search/TimeSearch-R/blob/main/time_r1/trainer/grpo_trainer_env.py#L874 (`def _get_per_token_logps(self, model, input_ids, attention_mask, multimodal_inputs, logits_to_keep, batch_size=None) -> torch.Tensor:`) includes the completion mask defined on line 1163, `completion_mask = compute_tool_response_mask(completion_ids) * completion_mask`, which masks out the tool outputs.
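To make the masking step concrete, here is a minimal sketch of what multiplying a tool-response mask into the completion mask does. It assumes tool outputs are delimited by two hypothetical token IDs (`TOOL_START_ID`, `TOOL_END_ID`) and that the delimiters themselves count as tool output. This stand-in for `compute_tool_response_mask` is illustrative only and may not match how TimeSearch-R actually detects tool responses; plain Python lists stand in for tensors.

```python
from typing import List

# Hypothetical delimiter token IDs; the real compute_tool_response_mask in
# TimeSearch-R may identify tool outputs in a different way.
TOOL_START_ID = 100
TOOL_END_ID = 101


def compute_tool_response_mask(completion_ids: List[int]) -> List[int]:
    """Return 1 for model-generated tokens and 0 for tool-response tokens
    (delimiters included), under the assumptions stated above."""
    mask = []
    inside = False
    for tok in completion_ids:
        if tok == TOOL_START_ID:
            inside = True
        mask.append(0 if inside else 1)
        if tok == TOOL_END_ID:
            inside = False
    return mask


completion_ids = [5, 6, 100, 42, 43, 101, 7, 0]
completion_mask = [1, 1, 1, 1, 1, 1, 1, 0]  # last token is padding

tool_mask = compute_tool_response_mask(completion_ids)
# Element-wise product, mirroring `compute_tool_response_mask(...) * completion_mask`.
completion_mask = [t * c for t, c in zip(tool_mask, completion_mask)]
print(completion_mask)  # [1, 1, 0, 0, 0, 0, 1, 0]
```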

Why are the tool outputs excluded from the attention mask when computing the log probabilities? Shouldn't they be masked out only in the loss?
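The distinction in the question can be sketched with a toy example. It assumes the common Hugging Face convention that a 0 in `attention_mask` hides that position as a key from every query, combined here with a causal mask. Under that assumption, zeroing the tool tokens means a generated token after the tool response is scored without the tool output in its context, even though it was sampled with it; masking only in the loss keeps the context intact and just drops tool positions from the average. The helpers `visible_positions` and `masked_mean`, the token layout, and the log-probability values are all hypothetical and do not reproduce TimeSearch-R's code.

```python
from typing import List


def visible_positions(attention_mask: List[int], query_pos: int) -> List[int]:
    """Key positions a query at `query_pos` can attend to under a causal mask
    combined with an attention mask (0 = hidden as a key)."""
    return [j for j in range(query_pos + 1) if attention_mask[j] == 1]


def masked_mean(values: List[float], mask: List[int]) -> float:
    """Average `values` over positions where mask == 1 (loss-only masking)."""
    total = sum(v * m for v, m in zip(values, mask))
    count = sum(mask)
    return total / max(count, 1)


# Layout: prompt tokens at 0-1, completion at 2-6 = [gen, gen, tool, tool, gen].
full_attention = [1, 1, 1, 1, 1, 1, 1]
tool_zeroed = [1, 1, 1, 1, 0, 0, 1]

# The final generated token (position 6) was sampled with the tool output in context.
print(visible_positions(full_attention, 6))  # [0, 1, 2, 3, 4, 5, 6]
print(visible_positions(tool_zeroed, 6))     # [0, 1, 2, 3, 6]

# Loss-only masking: tool positions still shape the context but are excluded from the average.
per_token_logps = [-0.5, -1.0, -3.0, -2.0, -1.5]  # completion tokens only (made-up values)
loss_mask = [1, 1, 0, 0, 1]
print(masked_mean(per_token_logps, loss_mask))  # -1.0
```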
