@@ -1411,7 +1411,22 @@ def transposed_convolution(
     channel_last: bool = False,
 ) -> torch.Tensor:
 
+    # Cadence transposed conv receives weights that have been transformed by the pass:
+    # 1. Transposed (dims 0 and 1 swapped): [out_channels, in_channels, *kernel]
+    # 2. Flipped (spatial dimensions reversed)
+    # We need to reverse both transformations to call PyTorch's conv_transpose.
+
     conv_is_1d = len(input_tensor.shape) == 3
+
+    # Determine flip dimensions based on weight dimensionality
+    weight_dim = len(weight.shape)
+    flip_dims = [-1] if weight_dim == 3 else [-1, -2]
+
+    # Undo transformation 2: unflip the spatial dimensions
+    weight = torch.flip(weight, dims=flip_dims)
+
+    # Undo transformation 1: transpose back to PyTorch format [in, out, *kernel]
+    weight = weight.transpose(0, 1).contiguous()
     if channel_last:
         if conv_is_1d:
             input_tensor = input_tensor.movedim(-1, 1).contiguous()
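
The weight handling in this hunk is a pure round trip. A minimal sketch checking it, assuming a pass that transposes dims 0 and 1 and then flips the spatial dims (tensor sizes here are illustrative):

import torch

w = torch.randn(4, 8, 3, 3)  # PyTorch conv_transpose2d layout: [in, out, kH, kW]
transformed = torch.flip(w.transpose(0, 1), dims=[-1, -2])  # what the pass produces
restored = torch.flip(transformed, dims=[-1, -2]).transpose(0, 1)  # reversal as above
assert torch.equal(w, restored)

Since the flip acts on the spatial dims and the transpose on the channel dims, the two operations commute, so the order of the reversal steps does not matter here.
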
@@ -1863,12 +1878,13 @@ def transposed_im2row(
     channel_last: bool = False,
 ) -> torch.Tensor:
     """
-    Converts input tensor patches into im2row format for transposed convolutions.
-    This function extracts patches from input in a pattern suitable for transposed convolution.
+    Converts input tensor into im2row format for transposed convolutions.
+    For each output position, extracts the kernel-sized patch of input values that
+    contribute to that position in a transposed convolution.
 
     Args:
     - input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D).
-    - kernel_size: Size of the convolution kernel.
+    - kernel_size: Size of the convolution kernel (kernel_h, kernel_w).
     - dilation: Dilation of the convolution kernel.
     - padding: Padding to apply to the input.
     - stride: Stride of the convolution.
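
Together with the body below, this docstring pins down the output shape: one row per output pixel, one column per (channel, kernel-offset) pair. A quick shape computation under assumed sizes (all values illustrative):

N, C, H_in, W_in = 1, 2, 4, 4
K_h, K_w = 3, 3
stride, padding, dilation, output_padding = (2, 2), (1, 1), (1, 1), (0, 0)
H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (K_h - 1) + output_padding[0] + 1
W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (K_w - 1) + output_padding[1] + 1
print((N, H_out * W_out, C * K_h * K_w))  # (1, 49, 18): [N, H_out * W_out, C * K_h * K_w] for NCHW input
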
@@ -1893,117 +1909,136 @@ def transposed_im2row(
         input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW
 
     N, C, H_in, W_in = input_tensor.shape
-
-    # Output: (N, C*H_in*W_in, H_out, W_out)
-    H_out = (
-        (H_in - 1) * stride[0]
-        + kernel_size[0]
-        + output_padding[0]
-        - 2 * padding[0]
-        + dilation[0] * (kernel_size[0] - 1)
-    )
-    W_out = (
-        (W_in - 1) * stride[1]
-        + kernel_size[1]
-        + output_padding[1]
-        - 2 * padding[1]
-        + dilation[1] * (kernel_size[1] - 1)
-    )
-
-    # For each input pixel, create a channel where the upsampled (transposed conv) patch is placed
-    # Output: (N, C*H_in*W_in, H_out, W_out)
-    inp_flat = input_tensor.reshape(N, C * H_in * W_in)
+    K_h, K_w = kernel_size
+    device = input_tensor.device
 
     # Calculate output spatial size
     H_out = (
         (H_in - 1) * stride[0]
         - 2 * padding[0]
-        + dilation[0] * (kernel_size[0] - 1)
+        + dilation[0] * (K_h - 1)
         + output_padding[0]
         + 1
     )
     W_out = (
         (W_in - 1) * stride[1]
         - 2 * padding[1]
-        + dilation[1] * (kernel_size[1] - 1)
+        + dilation[1] * (K_w - 1)
         + output_padding[1]
         + 1
     )
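+    # Worked example (assumed values): H_in = 4, stride_h = 2, pad_h = 1,
+    # dilation_h = 1, K_h = 3, output_padding_h = 0 gives
+    # H_out = (4 - 1) * 2 - 2 * 1 + 1 * (3 - 1) + 0 + 1 = 7, matching
+    # torch.nn.functional.conv_transpose2d's output-size formula.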
 
-    # Compute the upsampled (top-left) position for each input pixel
-    h_idx = torch.arange(H_in, device=input_tensor.device)
-    w_idx = torch.arange(W_in, device=input_tensor.device)
-    grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
-    out_h_idx = grid_h * stride[0] - padding[0]
-    out_w_idx = grid_w * stride[1] - padding[1]
-
-    # Compute all input pixel positions (flattened)
-    ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device)
-    ij_idx = ch_idx % (H_in * W_in)
-    i_idx = ij_idx // W_in
-    j_idx = ij_idx % W_in
-
-    # For each input pixel, compute the output positions for the kernel window
-    kh_idx = torch.arange(kernel_size[0], device=input_tensor.device)
-    kw_idx = torch.arange(kernel_size[1], device=input_tensor.device)
-    kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
-    kh_grid = kh_grid.reshape(-1)
-    kw_grid = kw_grid.reshape(-1)
-    num_kernel = kernel_size[0] * kernel_size[1]
-
-    # Broadcast to all channels and kernel positions
-    ch_idx_b = ch_idx.repeat_interleave(num_kernel)
-    n_kernel = ch_idx.shape[0] * num_kernel
-
-    i_idx_b = i_idx.repeat_interleave(num_kernel)
-    j_idx_b = j_idx.repeat_interleave(num_kernel)
-    kh_b = kh_grid.repeat(ch_idx.shape[0])
-    kw_b = kw_grid.repeat(ch_idx.shape[0])
-
-    h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0]
-    w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1]
-
-    # Mask for valid output positions
-    valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)
-
-    # Prepare indices for advanced indexing
-    n_idx = (
-        torch.arange(N, device=input_tensor.device)
-        .view(-1, 1)
-        .expand(N, n_kernel)
-        .reshape(-1)
-    )
-    ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
-    h_out_full = h_out.expand(N, n_kernel).reshape(-1)
-    w_out_full = w_out.expand(N, n_kernel).reshape(-1)
-    valid_full = valid.expand(N, n_kernel).reshape(-1)
-
-    # Gather input values for each channel
-    inp_vals = inp_flat[:, ch_idx_b].reshape(-1)
-
-    # Create output tensor
-    patches = torch.zeros((N, C * H_in * W_in, H_out, W_out), dtype=input_tensor.dtype)
+    # Create meshgrids for all output positions and kernel positions
+    h_out_grid = torch.arange(H_out, device=device).view(
+        -1, 1, 1, 1
+    )  # [H_out, 1, 1, 1]
+    w_out_grid = torch.arange(W_out, device=device).view(
+        1, -1, 1, 1
+    )  # [1, W_out, 1, 1]
+    kh_grid = torch.arange(K_h, device=device).view(1, 1, -1, 1)  # [1, 1, K_h, 1]
+    kw_grid = torch.arange(K_w, device=device).view(1, 1, 1, -1)  # [1, 1, 1, K_w]
+
+    # Compute input positions for all (h_out, w_out, kh, kw) combinations
+    # From C++ reference: h_im = _h - ((kernel_h - 1) * dilation_h) + _kh * dilation_h + pad_h
+    h_im = h_out_grid - (K_h - 1) * dilation[0] + kh_grid * dilation[0] + padding[0]
+    w_im = w_out_grid - (K_w - 1) * dilation[1] + kw_grid * dilation[1] + padding[1]
+
+    # Check which positions are valid (divisible by stride and within bounds)
+    # From C++ reference: if (h_im < 0 || h_im >= stride_h * height || h_im % stride_h != 0)
+    h_valid = (h_im % stride[0] == 0) & (h_im >= 0) & (h_im < stride[0] * H_in)
+    w_valid = (w_im % stride[1] == 0) & (w_im >= 0) & (w_im < stride[1] * W_in)
+    valid = h_valid & w_valid  # [H_out, W_out, K_h, K_w]
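+    # Example (assumed values): with stride_h = 2, only h_im offsets that are
+    # multiples of 2 map back to a real input row (h_in = h_im // 2); in-between
+    # positions correspond to the zeros inserted by upsampling and stay masked.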
+
+    # Actual input indices (h_im / stride_h from C++ reference)
+    h_in = h_im // stride[0]
+    w_in = w_im // stride[1]
+
+    # Clamp indices to valid range (will be masked out anyway)
+    h_in_safe = h_in.clamp(0, H_in - 1)
+    w_in_safe = w_in.clamp(0, W_in - 1)
+
+    # Initialize output patches with zero points (vectorized across batches)
+    # Layout depends on channel_last: NHWC uses [K_h, K_w, C], NCHW uses [C, K_h, K_w]
+    if channel_last:
+        # NHWC: patches layout [N, H_out, W_out, K_h, K_w, C]
+        patches = torch.zeros(
+            (N, H_out, W_out, K_h, K_w, C),
+            dtype=input_tensor.dtype,
+            device=device,
+        )
+    else:
+        # NCHW: patches layout [N, H_out, W_out, C, K_h, K_w]
+        patches = torch.zeros(
+            (N, H_out, W_out, C, K_h, K_w),
+            dtype=input_tensor.dtype,
+            device=device,
+        )
 
-    # If in_zero_point is provided, fill patches with it
+    # Initialize patches with zero points (vectorized)
     if in_zero_point is not None:
         if in_zero_point.numel() == 1:
+            # Scalar zero point - fill all patches
             patches.fill_(in_zero_point.item())
         else:
-            # Broadcast in_zero_point to (N, C, H_in, W_in)
-            assert in_zero_point.shape == (N,)
-            in_zero_point = in_zero_point.view(N, 1, 1, 1)
-            patches = patches + in_zero_point
-
-    # Scatter input values to output positions (only valid positions)
-    patches[
-        n_idx[valid_full],
-        ch_idx_full[valid_full],
-        h_out_full[valid_full],
-        w_out_full[valid_full],
-    ] = inp_vals[valid_full]
-
-    # Optionally, flatten to (N, num_patches, patch_size) if needed
-    patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
+            # Per-batch zero points - expand and fill
+            # in_zero_point: [N] -> [N, 1, 1, 1, 1, 1] (both layouts are 6-D)
+            zp_shape = [N] + [1] * (patches.ndim - 1)
+            patches = patches + in_zero_point.view(*zp_shape)
+
+    # Flatten the spatial and kernel dimensions for efficient gathering
+    # h_in_safe, w_in_safe: [H_out, W_out, K_h, K_w] (broadcast shape)
+    h_flat = h_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1)
+    w_flat = w_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1)
+
+    # Vectorized gathering across all batches and channels using advanced indexing
+    # Create index tensors with appropriate broadcasting shapes
+    num_positions = h_flat.shape[0]
+
+    # batch_indices: [N, 1, 1] -> broadcasts to [N, C, num_positions]
+    batch_indices = torch.arange(N, device=device).view(N, 1, 1)
+
+    # channel_indices: [1, C, 1] -> broadcasts to [N, C, num_positions]
+    channel_indices = torch.arange(C, device=device).view(1, C, 1)
+
+    # h_flat, w_flat: [1, 1, num_positions] -> broadcasts to [N, C, num_positions]
+    h_indices = h_flat.view(1, 1, num_positions)
+    w_indices = w_flat.view(1, 1, num_positions)
+
+    # Advanced indexing gathers all values at once: [N, C, num_positions]
+    gathered = input_tensor[batch_indices, channel_indices, h_indices, w_indices]
+
+    # Reshape based on channel_last flag
+    if channel_last:
+        # NHWC: reshape to [N, H_out, W_out, K_h, K_w, C]
+        # gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, H_out*W_out*K_h*K_w, C] -> [N, H_out, W_out, K_h, K_w, C]
+        gathered = gathered.transpose(1, 2).contiguous()  # [N, num_positions, C]
+        gathered = gathered.view(N, H_out, W_out, K_h, K_w, C)
+    else:
+        # NCHW: reshape to [N, H_out, W_out, C, K_h, K_w]
+        # gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, C, H_out, W_out, K_h, K_w] -> [N, H_out, W_out, C, K_h, K_w]
+        gathered = gathered.view(N, C, H_out, W_out, K_h, K_w)
+        gathered = gathered.permute(0, 2, 3, 1, 4, 5).contiguous()
+
+    # Apply validity mask (vectorized across batches)
+    # valid: [H_out, W_out, K_h, K_w] -> expand to match gathered shape
+    if channel_last:
+        # gathered: [N, H_out, W_out, K_h, K_w, C]
+        valid_exp = valid.unsqueeze(0).unsqueeze(-1)  # [1, H_out, W_out, K_h, K_w, 1]
+    else:
+        # gathered: [N, H_out, W_out, C, K_h, K_w]
+        valid_exp = valid.unsqueeze(0).unsqueeze(3)  # [1, H_out, W_out, 1, K_h, K_w]
+
+    patches = torch.where(valid_exp, gathered, patches)
+
+    # Reshape to final output format: [N, H_out * W_out, K_h * K_w * C]
+    # The reshaping preserves the correct dimension ordering
+    if channel_last:
+        # patches: [N, H_out, W_out, K_h, K_w, C] -> [N, H_out*W_out, K_h*K_w*C]
+        patches = patches.view(N, H_out * W_out, K_h * K_w * C)
+    else:
+        # patches: [N, H_out, W_out, C, K_h, K_w] -> [N, H_out*W_out, C*K_h*K_w]
+        patches = patches.view(N, H_out * W_out, C * K_h * K_w)
+
     return patches
 
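Downstream, rows in this format reduce a transposed convolution to a single matmul against the flattened kernel. A minimal sketch of that contraction, assuming NCHW layout and a weight already in the transformed format described in the first hunk ([out_channels, in_channels, K_h, K_w], spatially flipped); the function name and signature here are illustrative, not part of this commit:

import torch

def transposed_conv_from_rows(
    patches: torch.Tensor,  # [N, H_out * W_out, C_in * K_h * K_w] from transposed_im2row
    weight: torch.Tensor,   # [C_out, C_in, K_h, K_w], already transposed and flipped
    H_out: int,
    W_out: int,
) -> torch.Tensor:
    N = patches.shape[0]
    w_rows = weight.reshape(weight.shape[0], -1)  # [C_out, C_in * K_h * K_w]
    out = patches @ w_rows.T  # [N, H_out * W_out, C_out]
    return out.transpose(1, 2).reshape(N, -1, H_out, W_out)  # [N, C_out, H_out, W_out]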
20092044