From 801fb8159b29ecd142a6fe76ae5999fce35752a3 Mon Sep 17 00:00:00 2001 From: wnsplim Date: Thu, 28 May 2026 18:40:22 -0700 Subject: [PATCH] Attach per-haplotype tip labels from the input VCF --- SINGER/SINGER/convert_long_ARG.py | 30 +++++++++++++++ SINGER/SINGER/merge_ARG.py | 62 +++++++++++++++++++++++++++---- 2 files changed, 84 insertions(+), 8 deletions(-) diff --git a/SINGER/SINGER/convert_long_ARG.py b/SINGER/SINGER/convert_long_ARG.py index 8557ca8..e68b431 100644 --- a/SINGER/SINGER/convert_long_ARG.py +++ b/SINGER/SINGER/convert_long_ARG.py @@ -96,6 +96,36 @@ def write_output_ts(ts, output_prefix, MCMC_iteration): print(f"Save to {output_ts_filename}") ts.dump(output_ts_filename) +def read_vcf_sample_names(vcf_file): + for path in (vcf_file, vcf_file + ".vcf"): + if not os.path.exists(path): + continue + with open(path) as f: + for line in f: + if line.startswith("#CHROM"): + return line.strip().split("\t")[9:] + raise FileNotFoundError(f"Could not find VCF file: {vcf_file}") + +def add_individuals_from_vcf(ts, vcf_file): + sample_names = read_vcf_sample_names(vcf_file) + if len(sample_names) * 2 != ts.num_samples: + raise ValueError( + f"Expected {len(sample_names) * 2} sample nodes for {len(sample_names)} " + f"diploid VCF samples, got {ts.num_samples}." + ) + tables = ts.dump_tables() + node_individual = tables.nodes.individual.copy() + node_metadata = [b""] * tables.nodes.num_rows + for ind_id, name in enumerate(sample_names): + tables.individuals.add_row(metadata=name.encode()) + for hap in (0, 1): + nid = ind_id * 2 + hap + node_individual[nid] = ind_id + node_metadata[nid] = f"{name}_{hap}".encode() + tables.nodes.individual = node_individual + tables.nodes.packset_metadata(node_metadata) + return tables.tree_sequence() + def main(): # Argument parsing parser = argparse.ArgumentParser(description="Generate tskit format for a long ARG.") diff --git a/SINGER/SINGER/merge_ARG.py b/SINGER/SINGER/merge_ARG.py index 6bfd427..6f1c69c 100644 --- a/SINGER/SINGER/merge_ARG.py +++ b/SINGER/SINGER/merge_ARG.py @@ -72,8 +72,10 @@ def read_long_ARG(node_files, branch_files, mutation_files, block_coordinates): tables.mutations.add_row(site=site_id, node=int(mutations[i, 1]) + node_num, derived_state=str(int(mutations[i, 3]))) tables.sort() + tables.build_index() + tables.compute_mutation_parents() ts = tables.tree_sequence() - + return ts def load_file_lists(file_list_path): @@ -121,17 +123,58 @@ def sort_nodes_by_time(ts): child=node_map[edges.child].astype(np.int32), ) + muts = tables.mutations + muts.set_columns( + site=muts.site, + node=node_map[muts.node].astype(np.int32), + derived_state=muts.derived_state, + derived_state_offset=muts.derived_state_offset, + parent=muts.parent, + time=muts.time, + metadata=muts.metadata, + metadata_offset=muts.metadata_offset, + ) + tables.sort() tables.build_index() tables.compute_mutation_parents() tables.compute_mutation_times() - ts = tables.tree_sequence() return tables.tree_sequence() def write_output_ts(ts, output): print(f"Save to {output}") ts.dump(output) +def read_vcf_sample_names(vcf_file): + for path in (vcf_file, vcf_file + ".vcf"): + if not os.path.exists(path): + continue + with open(path) as f: + for line in f: + if line.startswith("#CHROM"): + return line.strip().split("\t")[9:] + raise FileNotFoundError(f"Could not find VCF file: {vcf_file}") + +def add_individuals_from_vcf(ts, vcf_file): + sample_names = read_vcf_sample_names(vcf_file) + if len(sample_names) * 2 != ts.num_samples: + raise ValueError( + f"Expected {len(sample_names) * 2} sample nodes for {len(sample_names)} " + f"diploid VCF samples, got {ts.num_samples}." + ) + tables = ts.dump_tables() + node_individual = tables.nodes.individual.copy() + node_metadata = [b""] * tables.nodes.num_rows + for ind_id, name in enumerate(sample_names): + tables.individuals.add_row(metadata=name.encode()) + for hap in (0, 1): + nid = ind_id * 2 + hap + node_individual[nid] = ind_id + node_metadata[nid] = f"{name}_{hap}".encode() + tables.nodes.individual = node_individual + tables.nodes.packset_metadata(node_metadata) + return tables.tree_sequence() + def main(): # Argument parsing parser = argparse.ArgumentParser(description="Generate tskit format for a long ARG.") @@ -139,17 +182,20 @@ def main(): # Add arguments with prefixes parser = argparse.ArgumentParser(description="Generate tskit format for a long ARG using file list.") parser.add_argument("--file_table", required=True, help="Sub file table") - parser.add_argument("--output", required=True, help="Output file name") - + parser.add_argument("--output", required=True, help="Output file name") + parser.add_argument("--vcf", required=False, default=None, + help="VCF (or prefix without .vcf) used as SINGER input. Sample " + "names from the header are attached as individuals, and tips " + "are named _0 / _1.") + args = parser.parse_args() - # Generate file lists node_files, branch_files, mutation_files, block_coordinates = load_file_lists(args.file_table) - # Apply the function - output_ts_filename = args.output ts = read_long_ARG(node_files, branch_files, mutation_files, block_coordinates) sorted_ts = sort_nodes_by_time(ts) - write_output_ts(sorted_ts, args.output) + if args.vcf is not None: + sorted_ts = add_individuals_from_vcf(sorted_ts, args.vcf) + write_output_ts(sorted_ts, args.output) if __name__ == "__main__": main()