41 changes: 26 additions & 15 deletions models/latent_actions.py
@@ -5,7 +5,7 @@
 import torch.distributed as dist
 import math
 from einops import rearrange, repeat, reduce
-from models.st_transformer import STTransformer, PatchEmbedding
+from models.st_transformer import STTransformer, PatchEmbedding, SpatialAttention
 from models.fsq import FiniteScalarQuantizer
 
 NUM_LATENT_ACTIONS_BINS = 2
@@ -17,10 +17,13 @@ def __init__(self, frame_size=(128, 128), patch_size=8, embed_dim=128, num_heads
         self.patch_embed = PatchEmbedding(frame_size, patch_size, embed_dim)
         self.transformer = STTransformer(embed_dim, num_heads, hidden_dim, num_blocks, causal=True)
 
+        # windowed attention for action tokenization
+        self.window_attn = SpatialAttention(embed_dim, num_heads)
+
         # embeddings to discrete latent bottleneck actions
         self.action_head = nn.Sequential(
-            nn.LayerNorm(embed_dim * 2),
-            nn.Linear(embed_dim * 2, 4 * action_dim),
+            nn.LayerNorm(embed_dim),
+            nn.Linear(embed_dim, 4 * action_dim),
             nn.GELU(),
             nn.Linear(4 * action_dim, action_dim)
         )
@@ -32,19 +35,27 @@ def forward(self, frames):
         embeddings = self.patch_embed(frames)  # [B, T, P, E]
         transformed = self.transformer(embeddings)
 
-        # TODO: try attention pooling + mean instead of mean + concat
-        # mean pool over patches (since one action per frame)
-        pooled = transformed.mean(dim=2)  # [B, T, E]
-
-        # combine features from current and next frame
-        actions = []
-        for t in range(seq_len - 1):
-            # concat current and next frame features
-            combined = torch.cat([pooled[:, t], pooled[:, t+1]], dim=1)  # [B, E*2]
-            action = self.action_head(combined)  # [B, A]
-            actions.append(action)
+        # windowed attention over length-2 windows (current and next frame)
+        # to combine features from frame t and frame t+1
+
+        # 1. Slice out the T-1 windows: frame t paired with frame t+1
+        # (current_frames[:, i] and next_frames[:, i] form window i)
+        current_frames = transformed[:, :-1]  # [B, T-1, P, E]
+        next_frames = transformed[:, 1:]      # [B, T-1, P, E]
+
+        # 2. Concatenate along the patch dimension so each window is one spatial sequence
+        # [B, T-1, 2*P, E]
+        windows = torch.cat([current_frames, next_frames], dim=2)
+
+        # 3. Apply spatial attention
+        # SpatialAttention expects [B, T, P, E]
+        attended = self.window_attn(windows)  # [B, T-1, 2*P, E]
+
+        # 4. Mean pool over the combined patches
+        pooled = attended.mean(dim=2)  # [B, T-1, E]
 
-        actions = torch.stack(actions, dim=1)  # [B, T-1, A]
+        # 5. Project to actions
+        actions = self.action_head(pooled)  # [B, T-1, A]
 
         return actions

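Note on SpatialAttention: the module is imported from models.st_transformer but its implementation is not part of this diff. For review context, a minimal sketch of a module matching the interface assumed above (input and output [B, T, P, E], self-attention over the patch dimension) could look like the following; the pre-norm and residual details are assumptions, not the actual implementation.

import torch
import torch.nn as nn

class SpatialAttention(nn.Module):
    """Hypothetical sketch: self-attention over the patch dimension of [B, T, P, E]."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        # x: [B, T, P, E]; fold time into the batch so attention only mixes
        # patches within the same frame (or, here, the same 2-frame window)
        B, T, P, E = x.shape
        h = x.reshape(B * T, P, E)
        h_norm = self.norm(h)
        attn_out, _ = self.attn(h_norm, h_norm, h_norm)
        return (h + attn_out).reshape(B, T, P, E)  # residual, back to [B, T, P, E]

Because the forward pass concatenates frame t and frame t+1 along the patch dimension before calling this module, the attention spans both frames of each window, which is what lets the pooled feature capture the change between frames rather than a single frame's content.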
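A quick shape check of the new windowed path, with illustrative sizes only (batch 2, 8 frames, 256 patches for 128x128 frames with patch_size 8, embed_dim 128); the mean here stands in for the attend-then-pool step:

import torch

B, T, P, E = 2, 8, 256, 128       # illustrative sizes, not from the diff
transformed = torch.randn(B, T, P, E)

windows = torch.cat([transformed[:, :-1], transformed[:, 1:]], dim=2)
print(windows.shape)              # torch.Size([2, 7, 512, 128]) -> [B, T-1, 2*P, E]

pooled = windows.mean(dim=2)      # stand-in for window_attn(...) followed by mean pooling
print(pooled.shape)               # torch.Size([2, 7, 128])      -> [B, T-1, E]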