
Commit 874cb34

Andrew Grebenisan authored and facebook-github-bot committed
Update more ops to use the new pass interface, update some ref implementations
Summary: Updated the following passes to use the new efficient pass interface, which also correctly sets the modified bit:
- FuseCascadedViewOps
- ReplaceAtenConvolutionWithCadenceConvolutionPass
- ReplaceTransposedConvWithLinearPass
- ReplaceNopTransposeOrPermuteWithViewPass
- ReplaceLinearWithFullyConnectedOpPass

Also fixes the transposed_im2row reference implementation, since numerical validation was failing for tests that introduced those ops.

Reviewed By: hsharma35

Differential Revision: D87837579
1 parent 78993e8 commit 874cb34

File tree: 7 files changed (+583, -381 lines)
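Before the per-file diffs, a rough sketch of the pattern the migrated passes now follow: instead of overriding ExportPass.call directly, each pass declares the op targets it cares about plus a per-node maybe_remove_or_replace hook. The base class itself is not part of this commit, so the driver below (the class name RemoveOrReplacePassSketch, the dead-code elimination, the recompile) is only an assumed outline inferred from the FuseCascadedViewOps diff and the summary, not the actual RemoveOrReplacePassInterface implementation.

# Hedged sketch only: illustrates the contract the migrated passes rely on
# (targets + maybe_remove_or_replace -> modified bit), not the real base class.
import torch
from executorch.exir.pass_base import ExportPass, PassResult


class RemoveOrReplacePassSketch(ExportPass):  # hypothetical stand-in name
    @property
    def targets(self) -> list:
        # Edge op overloads this pass cares about, e.g. [exir_ops.edge.aten.view_copy.default].
        raise NotImplementedError

    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
        # Return True iff the node was rewritten.
        raise NotImplementedError

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        modified = False
        for target in self.targets:
            # Visit only nodes whose target matches, instead of scanning every node.
            for node in graph_module.graph.find_nodes(
                op="call_function", target=target, sort=True
            ):
                modified |= self.maybe_remove_or_replace(node)
        if modified:
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        # Reporting `modified` accurately is what the summary calls
        # "correctly sets the modified bit".
        return PassResult(graph_module, modified)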

backends/cadence/aot/fuse_ops.py

Lines changed: 18 additions & 17 deletions
@@ -40,7 +40,6 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
 from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
-from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from torch.fx.node import Argument
 from torch.nn.utils.fusion import fuse_conv_bn_weights
@@ -523,29 +522,31 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:


 @register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseCascadedViewOps(ExportPass):
+class FuseCascadedViewOps(RemoveOrReplacePassInterface):
     """
     Fuse a cascaded chain of view ops
     """

-    def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule):
-        view_target = exir_ops.edge.aten.view_copy.default
-        for view_node in graph_module.graph.find_nodes(
-            op="call_function", target=view_target, sort=True
-        ):
-            input_view = view_node.args[0]
-            if input_view.op != "call_function" or input_view.target != view_target:
-                continue
+    @property
+    def targets(self) -> list[EdgeOpOverload]:
+        return [exir_ops.edge.aten.view_copy.default]

-            view_node.replace_input_with(input_view, input_view.args[0])
+    def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
+        # Check if the input to this view node is also a view node
+        input_view = node.args[0]
+        if not isinstance(input_view, torch.fx.Node):
+            return False

-        graph_module.recompile()
+        if (
+            input_view.op != "call_function"
+            or input_view.target != exir_ops.edge.aten.view_copy.default
+        ):
+            return False

-    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        self.fuse_cascaded_view_ops(graph_module)
-        dead_code_elimination_pass(graph_module)
-        result = super().call(graph_module)
-        return result
+        # Replace the input of this view node with the input of the cascaded view
+        # This effectively "skips" the intermediate view node
+        node.replace_input_with(input_view, cast(torch.fx.Node, input_view.args[0]))
+        return True


 class FuseOpPairsAcrossBranchesPass(ExportPass):
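A quick illustration (not from the commit) of why cascaded views can be fused: a view of a view reshapes the same underlying data, so the second view_copy can read directly from the original producer and the intermediate view node becomes dead code.

import torch

x = torch.arange(24.0)
chained = x.view(2, 12).view(4, 6)  # cascaded views, as the pass would see them
fused = x.view(4, 6)                # equivalent single view after fusion
assert torch.equal(chained, fused)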

backends/cadence/aot/ops_registrations.py

Lines changed: 12 additions & 3 deletions
@@ -2439,6 +2439,11 @@ def transposed_im2row_meta(
     in_zero_point: torch.Tensor,
     channel_last: bool = False,
 ) -> torch.Tensor:
+    """
+    Shape inference for transposed_im2row operation.
+
+    Returns shape: (N, H_out * W_out, K_h * K_w * C_in)
+    """
     if len(input.shape) == 3:
         height_dim = 1 if channel_last else 2
         input = input.unsqueeze(height_dim)
@@ -2447,6 +2452,8 @@ def transposed_im2row_meta(
     n_input_plane = input.shape[3] if channel_last else input.shape[1]
     input_height = input.shape[1] if channel_last else input.shape[2]
     input_width = input.shape[2] if channel_last else input.shape[3]
+
+    # Calculate output spatial dimensions
     output_height = (
         (input_height - 1) * stride[0]
         - 2 * padding[0]
@@ -2461,9 +2468,11 @@ def transposed_im2row_meta(
         + output_padding[1]
         + 1
     )
-    n_output_plane = n_input_plane * kernel_size[0] * kernel_size[1]
-    output_length = output_height * output_width
-    output_size = torch.Size((batch_size, output_length, n_output_plane))
+
+    # Patch size is kernel_h * kernel_w * in_channels
+    patch_size = kernel_size[0] * kernel_size[1] * n_input_plane
+    num_patches = output_height * output_width
+    output_size = torch.Size((batch_size, num_patches, patch_size))

     return input.new_empty(output_size, dtype=input.dtype)
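For reference, here is the shape formula from the meta function above worked through with small numbers (values chosen for this example, not taken from the commit):

# Worked example of the transposed_im2row_meta shape inference (illustrative values).
N, C_in, H_in, W_in = 1, 3, 5, 5
kernel_size = (3, 3)
stride, padding, dilation, output_padding = (2, 2), (1, 1), (1, 1), (0, 0)

H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1
W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1
num_patches = H_out * W_out                          # 9 * 9 = 81
patch_size = kernel_size[0] * kernel_size[1] * C_in  # 3 * 3 * 3 = 27

print((N, num_patches, patch_size))  # (1, 81, 27)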

backends/cadence/aot/ref_implementations.py

Lines changed: 129 additions & 94 deletions
@@ -1411,7 +1411,22 @@ def transposed_convolution(
     channel_last: bool = False,
 ) -> torch.Tensor:

+    # Cadence transposed conv receives weights that have been transformed by the pass:
+    # 1. Transposed (dims 0 and 1 swapped): [out_channels, in_channels, *kernel]
+    # 2. Flipped (spatial dimensions reversed)
+    # We need to reverse both transformations to call PyTorch's conv_transpose
+
     conv_is_1d = len(input_tensor.shape) == 3
+
+    # Determine flip dimensions based on weight dimensionality
+    weight_dim = len(weight.shape)
+    flip_dims = [-1] if weight_dim == 3 else [-1, -2]
+
+    # Reverse transformation step 1: Unflip the spatial dimensions
+    weight = torch.flip(weight, dims=flip_dims)
+
+    # Reverse transformation step 2: Transpose back to PyTorch format [in, out, *kernel]
+    weight = weight.transpose(0, 1).contiguous()
     if channel_last:
         if conv_is_1d:
             input_tensor = input_tensor.movedim(-1, 1).contiguous()
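As a sanity check on the weight handling above: the sketch below applies the pass-side transform described in the new comments (swap dims 0 and 1, then flip the spatial dims) to a standard conv_transpose2d weight and undoes it the same way the reference implementation now does. This is an illustration written for this page, not code from the commit.

import torch

# PyTorch conv_transpose2d weight layout: [in_channels, out_channels, kH, kW]
w_pt = torch.randn(4, 8, 3, 5)

# Pass-side transform per the comments above: transpose dims 0/1, then flip spatial dims.
w_cadence = w_pt.transpose(0, 1).flip([-1, -2])

# Reverse it as transposed_convolution now does: unflip, then transpose back.
w_roundtrip = torch.flip(w_cadence, dims=[-1, -2]).transpose(0, 1).contiguous()

assert torch.equal(w_roundtrip, w_pt)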
@@ -1863,12 +1878,13 @@ def transposed_im2row(
     channel_last: bool = False,
 ) -> torch.Tensor:
     """
-    Converts input tensor patches into im2row format for transposed convolutions.
-    This function extracts patches from input in a pattern suitable for transposed convolution.
+    Converts input tensor into im2row format for transposed convolutions.
+    For each output position, extracts the kernel-sized patch of input values that
+    contribute to that position in a transposed convolution.

     Args:
     - input_tensor: Input spatial tensor, NCHW or NHWC format (3D or 4D).
-    - kernel_size: Size of the convolution kernel.
+    - kernel_size: Size of the convolution kernel (kernel_h, kernel_w).
     - dilation: Dilation of the convolution kernel.
     - padding: Padding to apply to the input.
     - stride: Stride of the convolution.
@@ -1893,117 +1909,136 @@ def transposed_im2row(
         input_tensor = input_tensor.movedim(-1, -3).contiguous()  # NHWC -> NCHW

     N, C, H_in, W_in = input_tensor.shape
-
-    # Output: (N, C*H_in*W_in, H_out, W_out)
-    H_out = (
-        (H_in - 1) * stride[0]
-        + kernel_size[0]
-        + output_padding[0]
-        - 2 * padding[0]
-        + dilation[0] * (kernel_size[0] - 1)
-    )
-    W_out = (
-        (W_in - 1) * stride[1]
-        + kernel_size[1]
-        + output_padding[1]
-        - 2 * padding[1]
-        + dilation[1] * (kernel_size[1] - 1)
-    )
-
-    # For each input pixel, create a channel where the upsampled (transposed conv) patch is placed
-    # Output: (N, C*H_in*W_in, H_out, W_out)
-    inp_flat = input_tensor.reshape(N, C * H_in * W_in)
+    K_h, K_w = kernel_size
+    device = input_tensor.device

     # Calculate output spatial size
     H_out = (
         (H_in - 1) * stride[0]
         - 2 * padding[0]
-        + dilation[0] * (kernel_size[0] - 1)
+        + dilation[0] * (K_h - 1)
         + output_padding[0]
         + 1
     )
     W_out = (
         (W_in - 1) * stride[1]
         - 2 * padding[1]
-        + dilation[1] * (kernel_size[1] - 1)
+        + dilation[1] * (K_w - 1)
         + output_padding[1]
         + 1
     )

-    # Compute the upsampled (top-left) position for each input pixel
-    h_idx = torch.arange(H_in, device=input_tensor.device)
-    w_idx = torch.arange(W_in, device=input_tensor.device)
-    grid_h, grid_w = torch.meshgrid(h_idx, w_idx, indexing="ij")
-    out_h_idx = grid_h * stride[0] - padding[0]
-    out_w_idx = grid_w * stride[1] - padding[1]
-
-    # Compute all input pixel positions (flattened)
-    ch_idx = torch.arange(C * H_in * W_in, device=input_tensor.device)
-    ij_idx = ch_idx % (H_in * W_in)
-    i_idx = ij_idx // W_in
-    j_idx = ij_idx % W_in
-
-    # For each input pixel, compute the output positions for the kernel window
-    kh_idx = torch.arange(kernel_size[0], device=input_tensor.device)
-    kw_idx = torch.arange(kernel_size[1], device=input_tensor.device)
-    kh_grid, kw_grid = torch.meshgrid(kh_idx, kw_idx, indexing="ij")
-    kh_grid = kh_grid.reshape(-1)
-    kw_grid = kw_grid.reshape(-1)
-    num_kernel = kernel_size[0] * kernel_size[1]
-
-    # Broadcast to all channels and kernel positions
-    ch_idx_b = ch_idx.repeat_interleave(num_kernel)
-    n_kernel = ch_idx.shape[0] * num_kernel
-
-    i_idx_b = i_idx.repeat_interleave(num_kernel)
-    j_idx_b = j_idx.repeat_interleave(num_kernel)
-    kh_b = kh_grid.repeat(ch_idx.shape[0])
-    kw_b = kw_grid.repeat(ch_idx.shape[0])
-
-    h_out = out_h_idx[i_idx_b, j_idx_b] + kh_b * dilation[0]
-    w_out = out_w_idx[i_idx_b, j_idx_b] + kw_b * dilation[1]
-
-    # Mask for valid output positions
-    valid = (h_out >= 0) & (h_out < H_out) & (w_out >= 0) & (w_out < W_out)
-
-    # Prepare indices for advanced indexing
-    n_idx = (
-        torch.arange(N, device=input_tensor.device)
-        .view(-1, 1)
-        .expand(N, n_kernel)
-        .reshape(-1)
-    )
-    ch_idx_full = ch_idx_b.expand(N, n_kernel).reshape(-1)
-    h_out_full = h_out.expand(N, n_kernel).reshape(-1)
-    w_out_full = w_out.expand(N, n_kernel).reshape(-1)
-    valid_full = valid.expand(N, n_kernel).reshape(-1)
-
-    # Gather input values for each channel
-    inp_vals = inp_flat[:, ch_idx_b].reshape(-1)
-
-    # Create output tensor
-    patches = torch.zeros((N, C * H_in * W_in, H_out, W_out), dtype=input_tensor.dtype)
+    # Create meshgrids for all output positions and kernel positions
+    h_out_grid = torch.arange(H_out, device=device).view(
+        -1, 1, 1, 1
+    )  # [H_out, 1, 1, 1]
+    w_out_grid = torch.arange(W_out, device=device).view(
+        1, -1, 1, 1
+    )  # [1, W_out, 1, 1]
+    kh_grid = torch.arange(K_h, device=device).view(1, 1, -1, 1)  # [1, 1, K_h, 1]
+    kw_grid = torch.arange(K_w, device=device).view(1, 1, 1, -1)  # [1, 1, 1, K_w]
+
+    # Compute input positions for all (h_out, w_out, kh, kw) combinations
+    # From C++ reference: h_im = _h - ((kernel_h - 1) * dilation_h) + _kh * dilation_h + pad_h
+    h_im = h_out_grid - (K_h - 1) * dilation[0] + kh_grid * dilation[0] + padding[0]
+    w_im = w_out_grid - (K_w - 1) * dilation[1] + kw_grid * dilation[1] + padding[1]
+
+    # Check which positions are valid (divisible by stride and within bounds)
+    # From C++ reference: if (h_im < 0 || h_im >= stride_h * height || h_im % stride_h != 0)
+    h_valid = (h_im % stride[0] == 0) & (h_im >= 0) & (h_im < stride[0] * H_in)
+    w_valid = (w_im % stride[1] == 0) & (w_im >= 0) & (w_im < stride[1] * W_in)
+    valid = h_valid & w_valid  # [H_out, W_out, K_h, K_w]
+
+    # Actual input indices (h_im / stride_h from C++ reference)
+    h_in = h_im // stride[0]
+    w_in = w_im // stride[1]
+
+    # Clamp indices to valid range (will be masked out anyway)
+    h_in_safe = h_in.clamp(0, H_in - 1)
+    w_in_safe = w_in.clamp(0, W_in - 1)
+
+    # Initialize output patches with zero points (vectorized across batches)
+    # Layout depends on channel_last: NHWC uses [K_h, K_w, C], NCHW uses [C, K_h, K_w]
+    if channel_last:
+        # NHWC: patches layout [N, H_out, W_out, K_h, K_w, C]
+        patches = torch.zeros(
+            (N, H_out, W_out, K_h, K_w, C),
+            dtype=input_tensor.dtype,
+            device=device,
+        )
+    else:
+        # NCHW: patches layout [N, H_out, W_out, C, K_h, K_w]
+        patches = torch.zeros(
+            (N, H_out, W_out, C, K_h, K_w),
+            dtype=input_tensor.dtype,
+            device=device,
+        )

-    # If in_zero_point is provided, fill patches with it
+    # Initialize patches with zero points (vectorized)
     if in_zero_point is not None:
         if in_zero_point.numel() == 1:
+            # Scalar zero point - fill all patches
             patches.fill_(in_zero_point.item())
         else:
-            # Broadcast in_zero_point to (N, C, H_in, W_in)
-            assert in_zero_point.shape == (N,)
-            in_zero_point = in_zero_point.view(N, 1, 1, 1)
-            patches = patches + in_zero_point
-
-    # Scatter input values to output positions (only valid positions)
-    patches[
-        n_idx[valid_full],
-        ch_idx_full[valid_full],
-        h_out_full[valid_full],
-        w_out_full[valid_full],
-    ] = inp_vals[valid_full]
-
-    # Optionally, flatten to (N, num_patches, patch_size) if needed
-    patches = patches.view(N, C * H_in * W_in, -1).transpose(1, 2).contiguous()
+            # Per-batch zero points - expand and fill
+            # in_zero_point: [N] -> [N, 1, 1, 1, 1, 1]
+            zp_shape = [N] + [1] * (patches.ndim - 1)
+            patches = patches + in_zero_point.view(*zp_shape)
+
+    # Flatten the spatial and kernel dimensions for efficient gathering
+    # h_in_safe, w_in_safe: [H_out, W_out, K_h, K_w] (broadcast shape)
+    h_flat = h_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1)
+    w_flat = w_in_safe.expand(H_out, W_out, K_h, K_w).contiguous().view(-1)
+
+    # Vectorized gathering across all batches and channels using advanced indexing
+    # Create index tensors with appropriate broadcasting shapes
+    num_positions = h_flat.shape[0]
+
+    # batch_indices: [N, 1, 1] -> broadcasts to [N, C, num_positions]
+    batch_indices = torch.arange(N, device=device).view(N, 1, 1)
+
+    # channel_indices: [1, C, 1] -> broadcasts to [N, C, num_positions]
+    channel_indices = torch.arange(C, device=device).view(1, C, 1)
+
+    # h_flat, w_flat: [1, 1, num_positions] -> broadcasts to [N, C, num_positions]
+    h_indices = h_flat.view(1, 1, num_positions)
+    w_indices = w_flat.view(1, 1, num_positions)
+
+    # Advanced indexing gathers all values at once: [N, C, num_positions]
+    gathered = input_tensor[batch_indices, channel_indices, h_indices, w_indices]
+
+    # Reshape based on channel_last flag
+    if channel_last:
+        # NHWC: Reshape to [N, H_out, W_out, K_h, K_w, C]
+        # gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, H_out*W_out*K_h*K_w, C] -> [N, H_out, W_out, K_h, K_w, C]
+        gathered = gathered.transpose(1, 2).contiguous()  # [N, num_positions, C]
+        gathered = gathered.view(N, H_out, W_out, K_h, K_w, C)
+    else:
+        # NCHW: Reshape to [N, H_out, W_out, C, K_h, K_w]
+        # gathered: [N, C, H_out*W_out*K_h*K_w] -> [N, C, H_out, W_out, K_h, K_w] -> [N, H_out, W_out, C, K_h, K_w]
+        gathered = gathered.view(N, C, H_out, W_out, K_h, K_w)
+        gathered = gathered.permute(0, 2, 3, 1, 4, 5).contiguous()
+
+    # Apply validity mask (vectorized across batches)
+    # valid: [H_out, W_out, K_h, K_w] -> expand to match gathered shape
+    if channel_last:
+        # gathered: [N, H_out, W_out, K_h, K_w, C]
+        valid_exp = valid.unsqueeze(0).unsqueeze(-1)  # [1, H_out, W_out, K_h, K_w, 1]
+    else:
+        # gathered: [N, H_out, W_out, C, K_h, K_w]
+        valid_exp = valid.unsqueeze(0).unsqueeze(3)  # [1, H_out, W_out, 1, K_h, K_w]
+
+    patches = torch.where(valid_exp, gathered, patches)
+
+    # Reshape to final output format: [N, H_out * W_out, K_h * K_w * C]
+    # The reshaping will preserve the correct dimension ordering
+    if channel_last:
+        # patches: [N, H_out, W_out, K_h, K_w, C] -> [N, H_out*W_out, K_h*K_w*C]
+        patches = patches.view(N, H_out * W_out, K_h * K_w * C)
+    else:
+        # patches: [N, H_out, W_out, C, K_h, K_w] -> [N, H_out*W_out, C*K_h*K_w]
+        patches = patches.view(N, H_out * W_out, C * K_h * K_w)
+
     return patches
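To make the vectorized index math above easier to follow, here is a scalar-loop sketch of the same mapping for the NCHW branch, using the h_im/stride formulas quoted from the C++ reference in the comments, with the same column ordering (channel outermost, then kernel row, then kernel column). It is a readability aid written for this page (helper name and test values are illustrative), not the shipped reference implementation.

import torch


def naive_transposed_im2row_nchw(x, kernel, dilation, padding, stride, output_padding, zero_point=0):
    # Scalar-loop version of the gather above (NCHW only):
    # h_im = h_out - (K_h - 1) * dil_h + kh * dil_h + pad_h,
    # valid iff h_im is a non-negative multiple of stride_h below stride_h * H_in.
    N, C, H_in, W_in = x.shape
    K_h, K_w = kernel
    H_out = (H_in - 1) * stride[0] - 2 * padding[0] + dilation[0] * (K_h - 1) + output_padding[0] + 1
    W_out = (W_in - 1) * stride[1] - 2 * padding[1] + dilation[1] * (K_w - 1) + output_padding[1] + 1
    out = torch.full((N, H_out * W_out, C * K_h * K_w), float(zero_point), dtype=x.dtype)
    for n in range(N):
        for ho in range(H_out):
            for wo in range(W_out):
                for c in range(C):
                    for kh in range(K_h):
                        for kw in range(K_w):
                            h_im = ho - (K_h - 1) * dilation[0] + kh * dilation[0] + padding[0]
                            w_im = wo - (K_w - 1) * dilation[1] + kw * dilation[1] + padding[1]
                            if (
                                0 <= h_im < stride[0] * H_in and h_im % stride[0] == 0
                                and 0 <= w_im < stride[1] * W_in and w_im % stride[1] == 0
                            ):
                                out[n, ho * W_out + wo, (c * K_h + kh) * K_w + kw] = x[
                                    n, c, h_im // stride[0], w_im // stride[1]
                                ]
    return out


patches = naive_transposed_im2row_nchw(
    torch.randn(1, 2, 4, 4), kernel=(3, 3), dilation=(1, 1),
    padding=(1, 1), stride=(2, 2), output_padding=(1, 1),
)
print(patches.shape)  # torch.Size([1, 64, 18]): (N, H_out * W_out, C * K_h * K_w)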
