From e9be5e1ecd25ac09c4b1bbf7c7f43c2cc519ed50 Mon Sep 17 00:00:00 2001 From: Dom <97384583+tosemml@users.noreply.github.com> Date: Tue, 29 Aug 2023 00:55:53 -0700 Subject: [PATCH 1/3] use list comp --- simlarity/feature_extraction.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/simlarity/feature_extraction.py b/simlarity/feature_extraction.py index 547972e..c6e679e 100644 --- a/simlarity/feature_extraction.py +++ b/simlarity/feature_extraction.py @@ -179,11 +179,7 @@ def _warn_graph_differences(train_tracer: NodePathTracer, eval_tracer: NodePathT def _get_leaf_modules_for_ops() -> List[type]: members = inspect.getmembers(torchvision.ops) - result = [] - for _, obj in members: - if inspect.isclass(obj) and issubclass(obj, torch.nn.Module): - result.append(obj) - return result + return [obj for _, obj in members if inspect.isclass(obj) and issubclass(obj, torch.nn.Module)] def get_graph_node_names( From 3c12c836cb4f89f871fa21e31fd031700c009dd9 Mon Sep 17 00:00:00 2001 From: Dom <97384583+tosemml@users.noreply.github.com> Date: Tue, 29 Aug 2023 00:57:13 -0700 Subject: [PATCH 2/3] use list comp --- simlarity/feature_extraction.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/simlarity/feature_extraction.py b/simlarity/feature_extraction.py index c6e679e..cc30455 100644 --- a/simlarity/feature_extraction.py +++ b/simlarity/feature_extraction.py @@ -492,10 +492,7 @@ def to_strdict(n) -> Dict[str, str]: ) # Remove existing output nodes (train mode) - orig_output_nodes = [] - for n in reversed(graph_module.graph.nodes): - if n.op == "output": - orig_output_nodes.append(n) + orig_output_nodes = [n for n in reversed(graph_module.graph.nodes) if n.op == "output"] assert len(orig_output_nodes) for n in orig_output_nodes: graph_module.graph.erase_node(n) From d0ae71c6154423da63e6ddae808514acb9e8ad90 Mon Sep 17 00:00:00 2001 From: Dom <97384583+tosemml@users.noreply.github.com> Date: Tue, 29 Aug 2023 00:59:30 -0700 Subject: [PATCH 3/3] use list comp --- simlarity/feature_extraction.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/simlarity/feature_extraction.py b/simlarity/feature_extraction.py index cc30455..c87aed2 100644 --- a/simlarity/feature_extraction.py +++ b/simlarity/feature_extraction.py @@ -647,19 +647,13 @@ def to_strdict(n) -> Dict[str, str]: ) # Remove existing output nodes (train mode) - orig_output_nodes = [] - for n in reversed(graph_module.graph.nodes): - if n.op == "output": - orig_output_nodes.append(n) + orig_output_nodes = [n for n in reversed(graph_module.graph.nodes) if n.op == "output"] assert len(orig_output_nodes) for n in orig_output_nodes: graph_module.graph.erase_node(n) # Remove existing input nodes (train mode) - orig_input_nodes = [] - for n in reversed(graph_module.graph.nodes): - if n.op == "placeholder": - orig_input_nodes.append(n) + orig_input_nodes = [n for n in reversed(graph_module.graph.nodes) if n.op == "placeholder"] assert len(orig_input_nodes) # for n in orig_input_nodes: # n.users=()