diff --git a/chebi_utils/obo_extractor.py b/chebi_utils/obo_extractor.py index 1fc427f..fd427ca 100644 --- a/chebi_utils/obo_extractor.py +++ b/chebi_utils/obo_extractor.py @@ -63,7 +63,7 @@ def _term_data(doc: "fastobo.term.TermFrame") -> dict | None: } -def build_chebi_graph(filepath: str | Path) -> nx.DiGraph: +def build_chebi_graph(filepath: str | Path, top_class: str | None = "23367") -> nx.DiGraph: """Parse a ChEBI OBO file and build a directed graph of ontology terms. ``xref:`` lines are stripped before parsing as they can cause fastobo @@ -82,6 +82,12 @@ def build_chebi_graph(filepath: str | Path) -> nx.DiGraph: ---------- filepath : str or Path Path to the ChEBI OBO file. + top_class : str or None + ChEBI ID of the top class (default "23367" for "molecular entity"). + Only direct or indirect subclasses of the top class are kept + (the top class itself is excluded). If ``top_class`` is not part + of the ``is_a`` hierarchy of the parsed graph, the full graph is + returned. If None, the full graph is returned without subgraph extraction. 
Returns ------- @@ -113,7 +119,16 @@ def build_chebi_graph(filepath: str | Path) -> nx.DiGraph: for part_id in parts: graph.add_edge(node_id, part_id, relation=relation) - return graph + if top_class is None: + return graph + + hierarchy = get_hierarchy_subgraph(graph) + if top_class not in hierarchy: + return graph + + chebi_subgraph = graph.subgraph(nx.ancestors(hierarchy, top_class)) + assert isinstance(chebi_subgraph, nx.DiGraph) + return chebi_subgraph def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.DiGraph: @@ -122,3 +137,14 @@ def get_hierarchy_subgraph(chebi_graph: nx.DiGraph) -> nx.DiGraph: return chebi_graph.edge_subgraph( (u, v) for u, v, d in chebi_graph.edges(data=True) if d.get("relation") == "is_a" ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Build a ChEBI graph from an OBO file.") + parser.add_argument("obo_file", type=Path, help="Path to the ChEBI OBO file.") + args = parser.parse_args() + + graph = build_chebi_graph(args.obo_file) + print(f"Final graph: {len(graph.nodes)} nodes, {len(graph.edges)} edges")