
The number of channels is weird #296

@optstats


I ran unet.py on the CIFAR-10 dataset and printed the shape after each stage of the UNet (below). The channel counts look wrong to me. Is this expected?

Input shape: torch.Size([128, 3, 32, 32])
Time embedding shape: torch.Size([128, 256])
After image_proj shape: torch.Size([128, 64, 32, 32])

Downsampling process:
Down block 1 output shape: torch.Size([128, 64, 32, 32])
Down block 2 output shape: torch.Size([128, 64, 32, 32])
Down block 3 output shape: torch.Size([128, 64, 16, 16])
Down block 4 output shape: torch.Size([128, 128, 16, 16])
Down block 5 output shape: torch.Size([128, 128, 16, 16])
Down block 6 output shape: torch.Size([128, 128, 8, 8])
Down block 7 output shape: torch.Size([128, 256, 8, 8])
Down block 8 output shape: torch.Size([128, 256, 8, 8])
Down block 9 output shape: torch.Size([128, 256, 4, 4])
Down block 10 output shape: torch.Size([128, 1024, 4, 4])
Down block 11 output shape: torch.Size([128, 1024, 4, 4])

Middle block:
Middle block output shape: torch.Size([128, 1024, 4, 4])

Upsampling process:
Concatenated input shape before up block 1: torch.Size([128, 2048, 4, 4])
Up block 1 output shape: torch.Size([128, 1024, 4, 4])
Concatenated input shape before up block 2: torch.Size([128, 2048, 4, 4])
Up block 2 output shape: torch.Size([128, 1024, 4, 4])
Concatenated input shape before up block 3: torch.Size([128, 1280, 4, 4])
Up block 3 output shape: torch.Size([128, 256, 4, 4])
Upsample 4 output shape: torch.Size([128, 256, 8, 8])
Concatenated input shape before up block 5: torch.Size([128, 512, 8, 8])
Up block 5 output shape: torch.Size([128, 256, 8, 8])
Concatenated input shape before up block 6: torch.Size([128, 512, 8, 8])
Up block 6 output shape: torch.Size([128, 256, 8, 8])
Concatenated input shape before up block 7: torch.Size([128, 384, 8, 8])
Up block 7 output shape: torch.Size([128, 128, 8, 8])
Upsample 8 output shape: torch.Size([128, 128, 16, 16])
Concatenated input shape before up block 9: torch.Size([128, 256, 16, 16])
Up block 9 output shape: torch.Size([128, 128, 16, 16])
Concatenated input shape before up block 10: torch.Size([128, 256, 16, 16])
Up block 10 output shape: torch.Size([128, 128, 16, 16])
Concatenated input shape before up block 11: torch.Size([128, 192, 16, 16])
Up block 11 output shape: torch.Size([128, 64, 16, 16])
Upsample 12 output shape: torch.Size([128, 64, 32, 32])
Concatenated input shape before up block 13: torch.Size([128, 128, 32, 32])
Up block 13 output shape: torch.Size([128, 64, 32, 32])
Concatenated input shape before up block 14: torch.Size([128, 128, 32, 32])
Up block 14 output shape: torch.Size([128, 64, 32, 32])

Final output shape: torch.Size([128, 3, 32, 32])
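The following minimal shape-bookkeeping sketch in plain Python (no PyTorch needed) shows where the concatenated widths come from. It is reconstructed from the numbers in the log above, not from unet.py. Several details are assumptions inferred from the log: a base width of 64, per-level multipliers (1, 2, 4, 16), two res blocks per level, a downsample after every level except the last, and the rule that each up block outputs the channel count of the skip it consumes. `down_path` and `up_path` are hypothetical helpers.

```python
# Shape bookkeeping for the UNet log above (channels and spatial size only).
# Assumptions (inferred from the printed log, NOT taken from unet.py):
#   - base width 64, per-level multipliers (1, 2, 4, 16), 2 res blocks per level,
#   - a stride-2 downsample after every level except the last,
#   - every down-block output is pushed as a skip (the image_proj output is not),
#   - each up block outputs the channel count of the skip it consumes.


def down_path(base=64, mults=(1, 2, 4, 16), num_res=2, size=32):
    """Return (channels, size) for each down block, in order."""
    shapes = []
    for level, m in enumerate(mults):
        for _ in range(num_res):
            shapes.append((base * m, size))
        if level != len(mults) - 1:
            size //= 2
            shapes.append((base * m, size))  # downsample keeps the width
    return shapes


def up_path(skips, mid_ch, mid_size):
    """Pop skips in reverse order and record every upsample and concat."""
    trace = []
    ch, size = mid_ch, mid_size
    for skip_ch, skip_size in reversed(skips):
        if skip_size > size:  # resolution mismatch: upsample first
            size *= 2
            trace.append(("upsample", ch, size))
        concat = ch + skip_ch  # width after concatenating along the channel dim
        ch = skip_ch  # assumed: the up block projects back to the skip's width
        trace.append(("up", concat, ch, size))
    return trace


if __name__ == "__main__":
    skips = down_path()
    for i, (c, s) in enumerate(skips, 1):
        print(f"Down block {i}: {c} x {s} x {s}")
    for step in up_path(skips, mid_ch=1024, mid_size=4):
        print(step)
```

Under these assumptions the sketch reproduces every concatenated width in the log (2048, 2048, 1280, 512, 512, 384, 256, 256, 192, 128, 128). Each one is just the current width plus the width of the popped skip. The only step that breaks the doubling pattern is the last level, 256 → 1024 (a multiplier of 16 rather than 8). That jump is what makes the 4x4 concatenations 2048 and 1280 channels wide.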
