Skip to content

Commit 841d50f

Browse files
committed
Merge branch 'feat-webdataset-integration'
2 parents a9c0f5f + 14d256a commit 841d50f

5 files changed

Lines changed: 245 additions & 177 deletions

File tree

bichrom.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ dependencies:
1616
- pybedtools
1717
- pybigwig
1818
- pyfasta
19+
- webdataset=0.2
1920
# NN
2021
- pytorch=1.11
2122
- torchvision
2223
- torchaudio
2324
- cudatoolkit=11.3
24-
- pytorch-lightning=1.6
25+
- pytorch-lightning=1.7
2526
- jsonargparse
2627
- docstring_parser
2728
- torchmetrics=0.8
28-
- tensorflow=2.8

construct_data/construct_data.py

Lines changed: 52 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import os
21
import argparse
32
import yaml
43
import subprocess
5-
import numpy as np
64
import pandas as pd
75
from pybedtools import BedTool
86
from subprocess import call
@@ -132,7 +130,7 @@ def define_training_coordinates(chip_coords: pd.DataFrame, genome_sizes_file: st
132130
return training_coords_seq, training_coords_bichrom
133131

134132
def construct_training_set(genome_sizes_file, genome_fasta_file, peaks_file, blacklist_file, to_keep, to_filter,
135-
window_length, acc_regions_file, out_prefix, chromatin_track_list, tf_bam, nbins, augment_factor=5, p=1):
133+
window_length, acc_regions_file, out_prefix, chromatin_track_list, tf_bam, nbins, augment_factor=5, p=1, compress=False):
136134

137135
# prepare files for defining coordinates
138136
curr_genome_bdt = utils.get_genome_sizes(genome_sizes_file, to_keep=to_keep, to_filter=to_filter)
@@ -147,25 +145,28 @@ def construct_training_set(genome_sizes_file, genome_fasta_file, peaks_file, bla
147145
# get the coordinates for training samples
148146
train_coords_seq, train_coords_bichrom = define_training_coordinates(chip_seq_coordinates, genome_sizes_file, acc_bdt, curr_genome_bdt,
149147
blacklist_bdt, window_length, len(chip_seq_coordinates)*augment_factor, [450, -450, 500, -500, 1250, -1250, 1750, -1750], None, None)
150-
train_coords_seq.to_csv(out_prefix + "_seq.bed", header=False, index=False, sep="\t")
151-
train_coords_bichrom.to_csv(out_prefix + "_bichrom.bed", header=False, index=False, sep="\t")
148+
# save coordinates in bed files
149+
train_coords_seq_bed = out_prefix + "_seq.bed"
150+
train_coords_bichrom_bed = out_prefix + "_bichrom.bed"
151+
train_coords_seq.to_csv(train_coords_seq_bed, header=False, index=False, sep="\t")
152+
train_coords_bichrom.to_csv(train_coords_bichrom_bed, header=False, index=False, sep="\t")
152153

153154
# get fasta sequence and chromatin coverage according to the coordinates
154155
# write TFRecord output
155156
chroms_scaler = StandardScaler()
156-
TFRecord_file_seq_f = utils.get_data_TFRecord(train_coords_seq, genome_fasta_file, chromatin_track_list, tf_bam,
157-
nbins, outprefix=out_prefix + "_seq_forward" ,reverse=False, numProcessors=p)
158-
TFRecord_file_seq_r = utils.get_data_TFRecord(train_coords_seq, genome_fasta_file, chromatin_track_list, tf_bam,
159-
nbins, outprefix=out_prefix + "_seq_reverse",reverse=True, numProcessors=p)
160-
TFRecord_file_bichrom_f = utils.get_data_TFRecord(train_coords_bichrom, genome_fasta_file, chromatin_track_list, tf_bam,
161-
nbins, outprefix=out_prefix + "_bichrom_forward" ,reverse=False, numProcessors=p, chroms_scaler=chroms_scaler)
162-
TFRecord_file_bichrom_r = utils.get_data_TFRecord(train_coords_bichrom, genome_fasta_file, chromatin_track_list, tf_bam,
163-
nbins, outprefix=out_prefix + "_bichrom_reverse",reverse=True, numProcessors=p)
157+
wds_file_seq_f = utils.get_data_webdataset(train_coords_seq, genome_fasta_file, chromatin_track_list, tf_bam,
158+
nbins, outprefix=out_prefix + "_seq_forward" ,reverse=False, compress=compress, numProcessors=p)
159+
wds_file_seq_r = utils.get_data_webdataset(train_coords_seq, genome_fasta_file, chromatin_track_list, tf_bam,
160+
nbins, outprefix=out_prefix + "_seq_reverse",reverse=True, compress=compress, numProcessors=p)
161+
wds_file_bichrom_f = utils.get_data_webdataset(train_coords_bichrom, genome_fasta_file, chromatin_track_list, tf_bam,
162+
nbins, outprefix=out_prefix + "_bichrom_forward" ,reverse=False, compress=compress, numProcessors=p, chroms_scaler=chroms_scaler)
163+
wds_file_bichrom_r = utils.get_data_webdataset(train_coords_bichrom, genome_fasta_file, chromatin_track_list, tf_bam,
164+
nbins, outprefix=out_prefix + "_bichrom_reverse",reverse=True, compress=compress, numProcessors=p)
164165

165-
return TFRecord_file_seq_f + TFRecord_file_seq_r, TFRecord_file_bichrom_f + TFRecord_file_bichrom_r, chroms_scaler
166+
return wds_file_seq_f + wds_file_seq_r, wds_file_bichrom_f + wds_file_bichrom_r, train_coords_seq_bed, train_coords_bichrom_bed, chroms_scaler
166167

167168
def construct_test_set(genome_sizes_file, genome_fasta_file, peaks_file, blacklist_file, to_keep,
168-
window_length, stride, out_prefix, chromatin_track_list, tf_bam, nbins, p=1):
169+
window_length, stride, out_prefix, chromatin_track_list, tf_bam, nbins, p=1, compress=False):
169170

170171
# prepare file for defining coordinates
171172
blacklist_bdt = BedTool(blacklist_file)
@@ -183,13 +184,14 @@ def construct_test_set(genome_sizes_file, genome_fasta_file, peaks_file, blackli
183184
.assign(label=0, type="neg_chop"))
184185

185186
test_coords = pd.concat([bound_chip_peaks, unbound_genome_chop])
186-
test_coords.to_csv(out_prefix + ".bed", header=False, index=False, sep="\t")
187+
test_coords_bed = out_prefix + ".bed"
188+
test_coords.to_csv(test_coords_bed, header=False, index=False, sep="\t")
187189

188190
# write TFRecord output
189-
TFRecord_file = utils.get_data_TFRecord(test_coords, genome_fasta_file, chromatin_track_list, tf_bam,
190-
nbins, outprefix=out_prefix + "_forward" ,reverse=False, numProcessors=p)
191+
wds_file = utils.get_data_webdataset(test_coords, genome_fasta_file, chromatin_track_list, tf_bam,
192+
nbins, outprefix=out_prefix + "_forward" ,reverse=False, compress=compress, numProcessors=p)
191193

192-
return TFRecord_file
194+
return wds_file, test_coords_bed
193195

194196
def main():
195197

@@ -210,11 +212,12 @@ def main():
210212
required=True)
211213
parser.add_argument('-o', '--outdir', help='Output directory for storing train, test data',
212214
required=True)
213-
parser.add_argument('-nbins', type=int, help='Number of bins for chromatin tracks',
214-
required=True)
215-
parser.add_argument('-augment', type=int, help='Upsample positive set to AUGMENT times', default=5),
215+
parser.add_argument('-augment', type=int, help='Upsample positive set to AUGMENT times', default=5)
216+
216217
parser.add_argument('-p', type=int, help='Number of processors', default=1)
217218

219+
parser.add_argument('-compress', action='store_true', help='Whether to compress input datasets', default=False)
220+
218221
parser.add_argument('-blacklist', default=None, help='Optional, blacklist file for the genome of interest')
219222

220223
parser.add_argument('-val_chroms', default=['chr11'], nargs='+', help='A list of chromosomes to use for the validation set.')
@@ -256,7 +259,7 @@ def main():
256259
print([x.split('/')[-1].split('.')[0] for x in args.chromtracks])
257260

258261
print('Constructing train data ...')
259-
TFRecords_train_seq, TFRecords_train_bichrom, chroms_scaler = construct_training_set(genome_sizes_file=args.info,
262+
wds_train_seq, wds_train_bichrom, train_coords_seq_bed, train_coords_bichrom_bed, chroms_scaler = construct_training_set(genome_sizes_file=args.info,
260263
genome_fasta_file=args.fa,
261264
peaks_file=args.peaks,
262265
blacklist_file=args.blacklist, window_length=args.len,
@@ -266,12 +269,13 @@ def main():
266269
out_prefix=args.outdir + '/data_train',
267270
chromatin_track_list=args.chromtracks,
268271
tf_bam=args.tfbam,
269-
nbins=args.nbins,
272+
nbins=args.len,
270273
augment_factor=args.augment,
271-
p=args.p)
274+
p=args.p,
275+
compress=args.compress)
272276

273277
print('Constructing validation data ...')
274-
TFRecords_val = construct_test_set(genome_sizes_file=args.info,
278+
wds_val, val_coords_bed = construct_test_set(genome_sizes_file=args.info,
275279
peaks_file=args.peaks,
276280
genome_fasta_file=args.fa,
277281
blacklist_file=args.blacklist, window_length=args.len,
@@ -280,10 +284,10 @@ def main():
280284
out_prefix=args.outdir + '/data_val',
281285
chromatin_track_list=args.chromtracks,
282286
tf_bam=args.tfbam,
283-
nbins=args.nbins, p=args.p)
287+
nbins=args.len, p=args.p, compress=args.compress)
284288

285289
print('Constructing test data ...')
286-
TFRecords_test = construct_test_set(genome_sizes_file=args.info,
290+
wds_test, test_coords_bed = construct_test_set(genome_sizes_file=args.info,
287291
peaks_file=args.peaks,
288292
genome_fasta_file=args.fa,
289293
blacklist_file=args.blacklist, window_length=args.len,
@@ -292,52 +296,32 @@ def main():
292296
out_prefix=args.outdir + '/data_test',
293297
chromatin_track_list=args.chromtracks,
294298
tf_bam=args.tfbam,
295-
nbins=args.nbins, p=args.p)
299+
nbins=args.len, p=args.p, compress=args.compress)
296300

297301
# Produce a default yaml file recording the output
298-
yml_training_schema = {'train_seq': {'seq': 'seq',
299-
'labels': 'labels',
302+
yml_training_schema = {'params': {
300303
'chromatin_tracks': args.chromtracks,
301304
'tf_bam': args.tfbam,
302305
'fasta': args.fa,
303-
'nbins': args.nbins,
304-
'TFRecord': TFRecords_train_seq},
305-
'train_bichrom': {'seq': 'seq',
306-
'labels': 'labels',
307-
'chromatin_tracks': args.chromtracks,
308-
'tf_bam': args.tfbam,
309-
'fasta': args.fa,
310-
'nbins': args.nbins,
311-
'TFRecord': TFRecords_train_bichrom,
312306
'scaler_mean': chroms_scaler.mean_.tolist(),
313-
'scaler_var': chroms_scaler.var_.tolist()},
314-
'val': {'seq': 'seq',
315-
'labels': 'labels',
316-
'chromatin_tracks': args.chromtracks,
317-
'tf_bam': args.tfbam,
318-
'fasta': args.fa,
319-
'nbins': args.nbins,
320-
'TFRecord': TFRecords_val},
321-
'test': {'seq': 'seq',
322-
'labels': 'labels',
323-
'chromatin_tracks': args.chromtracks,
324-
'tf_bam': args.tfbam,
325-
'fasta': args.fa,
326-
'nbins': args.nbins,
327-
'TFRecord': TFRecords_test}}
328-
329-
logging.info("Indexing TFRecord files...")
330-
for name, dspath in yml_training_schema.items():
331-
tfrecords = dspath['TFRecord']
332-
tfrecord_idxs = [i.replace("TFRecord", "idx") for i in tfrecords]
333-
tfrecord2idx_script = "tfrecord2idx"
334-
335-
for index, tfrecord in enumerate(tfrecords):
336-
tfrecord_idx = tfrecord_idxs[index]
337-
if not os.path.isfile(tfrecord_idx):
338-
call([tfrecord2idx_script, tfrecord, tfrecord_idx])
339-
340-
dspath['TFRecord_idx'] = tfrecord_idxs
307+
'scaler_var': chroms_scaler.var_.tolist()
308+
},
309+
'train_seq': {
310+
'bed': train_coords_seq_bed,
311+
'webdataset': wds_train_seq
312+
},
313+
'train_bichrom': {
314+
'bed': train_coords_bichrom_bed,
315+
'webdataset': wds_train_bichrom
316+
},
317+
'val': {
318+
'bed': val_coords_bed,
319+
'webdataset': wds_val
320+
},
321+
'test': {
322+
'bed': test_coords_bed,
323+
'webdataset': wds_test
324+
}}
341325

342326
# Note: The x.split('/')[-1].split('.')[0] accounts for input chromatin bigwig files with
343327
# associated directory paths

construct_data/utils.py

Lines changed: 58 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from pybedtools import Interval, BedTool
1919
from sklearn.preprocessing import StandardScaler
2020

21-
import tensorflow as tf
21+
import webdataset as wds
2222

2323
def filter_chromosomes(input_df, to_filter=None, to_keep=None):
2424
"""
@@ -280,7 +280,7 @@ def get_data(coords, genome_fasta, chromatin_tracks, nbins, reverse=False, numPr
280280

281281
return X_seq, chromatin_out_lists, y
282282

283-
def get_data_TFRecord(coords, genome_fasta, chromatin_tracks, tf_bam, nbins, outprefix, reverse=False, numProcessors=1, chroms_scaler=None):
283+
def get_data_webdataset(coords, genome_fasta, chromatin_tracks, tf_bam, nbins, outprefix, reverse=False, compress=False, numProcessors=1, chroms_scaler=None):
284284
"""
285285
Given coordinates dataframe, extract the sequence and chromatin signal,
286286
Then save in **WebDataset** format
@@ -293,13 +293,13 @@ def get_data_TFRecord(coords, genome_fasta, chromatin_tracks, tf_bam, nbins, out
293293
# freeze the common parameters
294294
## create a scaler to get statistics for normalizing chromatin marks input
295295
## also create a multiprocessing lock
296-
get_data_TFRecord_worker_freeze = functools.partial(get_data_TFRecord_worker,
296+
get_data_worker_freeze = functools.partial(get_data_webdataset_worker,
297297
fasta=genome_fasta, nbins=nbins,
298298
bigwig_files=chromatin_tracks, tf_bam=tf_bam,
299-
reverse=reverse)
299+
reverse=reverse, compress=compress)
300300

301301
pool = Pool(numProcessors)
302-
res = pool.starmap_async(get_data_TFRecord_worker_freeze, zip(chunks, [outprefix + "_" + str(i) for i in range(num_chunks)]))
302+
res = pool.starmap_async(get_data_worker_freeze, zip(chunks, [outprefix + "_" + str(i) for i in range(num_chunks)]))
303303
res = res.get()
304304

305305
# fit the scaler if provided
@@ -311,6 +311,59 @@ def get_data_TFRecord(coords, genome_fasta, chromatin_tracks, tf_bam, nbins, out
311311

312312
return files
313313

314+
def get_data_webdataset_worker(coords, outprefix, fasta, bigwig_files, tf_bam, nbins, reverse=False, compress=False):
315+
# get handlers
316+
genome_pyfasta = pyfasta.Fasta(fasta)
317+
bigwigs = [pyBigWig.open(bw) for bw in bigwig_files]
318+
tfbam = pysam.AlignmentFile(tf_bam)
319+
320+
# iterate all records
321+
filename = f"{outprefix}.tar.gz" if compress else f"{outprefix}.tar"
322+
sink = wds.TarWriter(filename, compress=compress)
323+
mss = []
324+
for item in coords.itertuples():
325+
feature_dict = defaultdict()
326+
feature_dict["__key__"] = f"{item.chrom}:{item.start}-{item.end}"
327+
328+
# seq
329+
seq = genome_pyfasta[item.chrom][int(item.start):int(item.end)]
330+
if reverse:
331+
seq = rev_comp(seq)
332+
seq_array = dna2onehot(seq)
333+
feature_dict["seq.npy"] = seq_array
334+
335+
#chromatin track
336+
ms = []
337+
try:
338+
for idx, bigwig in enumerate(bigwigs):
339+
m = (np.nan_to_num(bigwig.values(item.chrom, item.start, item.end))
340+
.reshape((nbins, -1))
341+
.mean(axis=1, dtype=np.float32))
342+
if reverse:
343+
m = m[::-1]
344+
ms.append(m)
345+
except RuntimeError as e:
346+
logging.warning(e)
347+
logging.warning(f"Chromatin track {bigwig_files[idx]} doesn't have information in {item} Skip this region...")
348+
continue
349+
ms = np.vstack(ms) # create the chromatin track array, shape (num_tracks, length)
350+
feature_dict["chrom.npy"] = ms
351+
mss.append(ms)
352+
# label
353+
feature_dict["label.npy"] = np.array(item.label, dtype=np.int32)[np.newaxis]
354+
# counts
355+
target = tfbam.count(item.chrom, item.start, item.end)
356+
feature_dict["target.npy"] = np.array(target, dtype=np.float32)[np.newaxis]
357+
358+
sink.write(feature_dict)
359+
360+
sink.close()
361+
for bw in bigwigs: bw.close()
362+
363+
mss = np.hstack(mss).T
364+
365+
return filename, mss
366+
314367
def get_data_TFRecord_worker(coords, outprefix, fasta, bigwig_files, tf_bam, nbins, reverse=False):
315368

316369
genome_pyfasta = pyfasta.Fasta(fasta)

0 commit comments

Comments
 (0)