
 import torch
 from torch import nn
-from torch.nn.functional import fold, unfold

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
@@ -532,7 +531,19 @@ def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
         Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
         // patch_size)` is the number of patches.
     """
-    return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
+    b, c, h, w = img.shape
+    p = patch_size
+
+    # Reshape to (B, C, H//p, p, W//p, p) separating grid and patch dimensions
+    img = img.reshape(b, c, h // p, p, w // p, p)
+
+    # Permute to (B, H//p, W//p, C, p, p) using einsum
+    # n=batch, c=channels, h=grid_height, p=patch_height, w=grid_width, q=patch_width
+    img = torch.einsum("nchpwq->nhwcpq", img)
+
+    # Flatten to (B, L, C * p * p)
+    img = img.reshape(b, -1, c * p * p)
+    return img


 def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
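The reshape/einsum path above should produce exactly the same patch sequence as the `unfold(...).transpose(1, 2)` call it replaces, since non-overlapping `p x p` patches are just a permutation of the input values. A minimal sketch of that equivalence check (the helper name `img2seq_einsum` and the tensor sizes are illustrative, not part of this change):

    import torch
    from torch.nn.functional import unfold

    def img2seq_einsum(img: torch.Tensor, patch_size: int) -> torch.Tensor:
        # Same logic as the new img2seq above.
        b, c, h, w = img.shape
        p = patch_size
        img = img.reshape(b, c, h // p, p, w // p, p)
        img = torch.einsum("nchpwq->nhwcpq", img)
        return img.reshape(b, -1, c * p * p)

    x = torch.randn(2, 4, 32, 32)  # illustrative (B, C, H, W)
    old = unfold(x, kernel_size=8, stride=8).transpose(1, 2)  # (B, L, C*p*p)
    new = img2seq_einsum(x, patch_size=8)
    assert torch.equal(old, new)
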
@@ -554,12 +565,26 @@ def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
         Reconstructed image tensor of shape `(B, C, H, W)`.
     """
     if isinstance(shape, tuple):
-        shape = shape[-2:]
+        h, w = shape[-2:]
     elif isinstance(shape, torch.Tensor):
-        shape = (int(shape[0]), int(shape[1]))
+        h, w = (int(shape[0]), int(shape[1]))
     else:
         raise NotImplementedError(f"shape type {type(shape)} not supported")
-    return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
+
+    b, l, d = seq.shape
+    p = patch_size
+    c = d // (p * p)
+
+    # Reshape back to grid structure: (B, H//p, W//p, C, p, p)
+    seq = seq.reshape(b, h // p, w // p, c, p, p)
+
+    # Permute back to image layout: (B, C, H//p, p, W//p, p)
+    # n=batch, h=grid_height, w=grid_width, c=channels, p=patch_height, q=patch_width
+    seq = torch.einsum("nhwcpq->nchpwq", seq)
+
+    # Final reshape to (B, C, H, W)
+    seq = seq.reshape(b, c, h, w)
+    return seq


 class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
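
For non-overlapping patches (stride equal to the kernel size), the einsum-based seq2img is an exact inverse of the patchify step, just as `fold` was for `unfold`. A minimal roundtrip sketch, with illustrative sizes (batch 2, 4 channels, a 32x32 image, patch size 8):

    import torch
    from torch.nn.functional import unfold

    p = 8
    x = torch.randn(2, 4, 32, 32)  # (B, C, H, W)

    # Patchify with the old unfold path, which the new img2seq matches exactly.
    seq = unfold(x, kernel_size=p, stride=p).transpose(1, 2)  # (B, L, C*p*p)

    # Unpatchify with the new reshape/einsum path (mirrors seq2img above).
    b, l, d = seq.shape
    h, w = x.shape[-2:]
    c = d // (p * p)
    y = seq.reshape(b, h // p, w // p, c, p, p)
    y = torch.einsum("nhwcpq->nchpwq", y)
    y = y.reshape(b, c, h, w)

    # The original image is recovered exactly: every patch element is copied
    # once, so no summation over overlapping windows is involved.
    assert torch.equal(y, x)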