Skip to content

Commit 6eb4e80

Browse files
vimalmanohar authored and naxingyu committed
[scripts] allowed_durations computation in standalone script
1 parent cfda4cf commit 6eb4e80

File tree

1 file changed

+219
-0
lines changed

1 file changed

+219
-0
lines changed
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
#!/usr/bin/env python3
2+
3+
# Copyright 2017 Hossein Hadian
4+
# 2019 Facebook Inc. (Author: Vimal Manohar)
5+
# Apache 2.0
6+
7+
8+
""" This script generates a set of allowed lengths of utterances
9+
spaced by a factor (like 10%). This is useful for generating
10+
fixed-length chunks for chain training.
11+
"""
12+
13+
import argparse
14+
import os
15+
import sys
16+
import copy
17+
import math
18+
import logging
19+
20+
sys.path.insert(0, 'steps')
21+
import libs.common as common_lib
22+
23+
# Console logging for this script: records of level INFO and above sent to
# the 'libs' logger are formatted with timestamp, source location, function
# name and level, and emitted through a StreamHandler.
formatter = logging.Formatter(
    "%(asctime)s [%(pathname)s:%(lineno)s - "
    "%(funcName)s - %(levelname)s ] %(message)s"
)

handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(logging.INFO)

logger = logging.getLogger('libs')
logger.addHandler(handler)
logger.setLevel(logging.INFO)
31+
32+
def get_args():
    """Parse and validate the command-line arguments of this script.

    The arguments are read from ``sys.argv`` by argparse. Values that would
    make the later computations divide by zero or loop forever (negative
    spacing factor, non-positive frame shift or subsampling factor, coverage
    factor outside [0, 50)) are rejected through ``parser.error``, which
    prints the usage message and exits.

    Returns
    -------
    args: argparse.Namespace
        Has the attributes factor, data_dir, dir, coverage_factor,
        frame_shift, frame_length and frame_subsampling_factor.
    """
    parser = argparse.ArgumentParser(description="""
    This script creates a list of allowed durations of utterances for flatstart
    LF-MMI training corresponding to input data directory 'data_dir' and writes
    it in two files in output directory 'dir':
    1) allowed_durs.txt -- durations are in seconds
    2) allowed_lengths.txt -- lengths are in number of frames

    Both the allowed_durs.txt and allowed_lengths.txt are formatted to
    have one entry on each line. Examples are as follows:

    $ cat data/train/allowed_lengths.txt
    414
    435
    468

    $ cat data/train/allowed_durs.txt
    4.16
    4.37
    4.70

    These files can then be used by a downstream script to perturb the
    utterances to these lengths.
    A perturbed data directory (created by a downstream script
    similar to utils/data/perturb_speed_to_allowed_lengths.py)
    that only contains utterances of these allowed durations,
    along with the corresponding allowed_lengths.txt are
    consumed by the e2e chain egs preparation script.
    See steps/nnet3/chain/e2e/get_egs_e2e.sh for how these are used.

    See also:
    * egs/cifar/v1/image/get_allowed_lengths.py -- a similar script for OCR datasets
    * utils/data/perturb_speed_to_allowed_lengths.py --
    creates the allowed_lengths.txt AND perturbs the data directory
    """)
    # NOTE: 'factor' is a required positional, so argparse never applies the
    # default below; it is kept unchanged so the parser definition stays
    # identical to the original.
    parser.add_argument('factor', type=float, default=12,
                        help='Spacing (in percentage) between allowed lengths. '
                        'Can be 0, which means all seen lengths that are a multiple of '
                        'frame_subsampling_factor will be allowed.')
    parser.add_argument('data_dir', type=str, help='path to data dir. Assumes that '
                        'it contains the utt2dur file.')
    parser.add_argument('dir', type=str, help='We write the output files '
                        'allowed_lengths.txt and allowed_durs.txt to this directory.')
    parser.add_argument('--coverage-factor', type=float, default=0.05,
                        help="""Percentage of durations not covered from each
                        side of duration histogram.""")
    parser.add_argument('--frame-shift', type=int, default=10,
                        help="""Frame shift in milliseconds.""")
    parser.add_argument('--frame-length', type=int, default=25,
                        help="""Frame length in milliseconds.""")
    parser.add_argument('--frame-subsampling-factor', type=int, default=3,
                        help="""Chain frame subsampling factor.
                        See steps/nnet3/chain/train.py""")
    args = parser.parse_args()

    # A negative spacing turns the multiplicative step into a shrink, so the
    # duration would never reach the end of the range.
    if args.factor < 0:
        parser.error("factor must be >= 0, got {}".format(args.factor))
    # Both values are used as divisors when converting durations to frames.
    if args.frame_shift <= 0:
        parser.error("--frame-shift must be positive, got {}".format(
            args.frame_shift))
    if args.frame_subsampling_factor <= 0:
        parser.error("--frame-subsampling-factor must be positive, got {}".format(
            args.frame_subsampling_factor))
    # The percentage is trimmed from each side of the histogram, so 50 or
    # more would leave nothing to cover.
    if not 0 <= args.coverage_factor < 50:
        parser.error("--coverage-factor must be in [0, 50), got {}".format(
            args.coverage_factor))
    return args
87+
88+
89+
def read_kaldi_mapfile(path):
    """Read a Kaldi mapping file (text, utt2dur, .scp files, etc.) into a dict.

    Each non-blank line is split at its first space or tab: the part before
    is the key, the rest (with leading spaces/tabs removed) is the value.
    A line that holds only a key maps to the empty string. Blank lines are
    skipped. If a key occurs more than once, the last occurrence wins.

    Parameters
    ----------
    path: str
        Path of the file to read. It is decoded as latin-1.

    Returns
    -------
    m: dict
        Maps each key (str) to its value (str).
    """
    m = {}
    with open(path, 'r', encoding='latin-1') as f:
        for line in f:
            line = line.strip(" \t\r\n")
            if not line:
                continue
            # Split on space/tab only. A bare str.split() would also split on
            # characters such as U+00A0 or U+0085, which show up inside keys
            # when UTF-8 bytes are decoded as latin-1.
            cuts = [pos for pos in (line.find(' '), line.find('\t')) if pos >= 0]
            if not cuts:
                # Key without a value; str.find returned -1 for both
                # separators, so slicing with it would corrupt the key.
                m[line] = ''
                continue
            cut = min(cuts)
            m[line[:cut]] = line[cut + 1:].lstrip(" \t")
    return m
102+
103+
104+
def find_duration_range(utt2dur, coverage_factor):
    """Find the start and end duration of the range of durations to cover.

    If we try to cover all durations which occur in the training set, the
    number of allowed lengths could become very large. So the shortest and
    the longest utterances are left out: walking over the sorted durations
    from each end, utterances are skipped until their accumulated duration
    exceeds coverage_factor percent of the total duration.

    Parameters
    ----------
    utt2dur: dict
        Maps utterance-id to duration in seconds; values must be
        convertible with float().
    coverage_factor: float
        Percentage of the total duration to leave uncovered on each side.

    Returns
    -------
    start_dur: float
        Never smaller than 0.3.
    end_dur: float

    Raises
    ------
    ValueError
        If utt2dur is empty, the total duration is not positive, or the
        accumulated duration never exceeds coverage_factor percent.
    """
    durs = sorted(float(val) for val in utt2dur.values())
    if not durs:
        raise ValueError("utt2dur is empty; cannot determine a duration range")
    tot_dur = sum(durs)
    if tot_dur <= 0:
        raise ValueError("Total duration must be positive, got {}".format(tot_dur))

    def first_beyond_coverage(ordered_durs):
        # First duration at which the running total of skipped durations
        # exceeds coverage_factor percent of the total.
        to_ignore_dur = 0.0
        for d in ordered_durs:
            to_ignore_dur += d
            if to_ignore_dur * 100.0 / tot_dur > coverage_factor:
                return d
        raise ValueError("coverage_factor={} leaves no duration to cover"
                         "".format(coverage_factor))

    start_dur = first_beyond_coverage(durs)
    end_dur = first_beyond_coverage(reversed(durs))
    if start_dur < 0.3:
        start_dur = 0.3  # a hard limit to avoid too many allowed lengths --not critical
    return start_dur, end_dur
134+
135+
136+
def get_allowed_durations(start_dur, end_dur, args):
    """Given the start and end duration, find a set of allowed durations
    spaced by the multiplicative factor args.factor. Also write out the list
    of allowed durations and the corresponding allowed lengths (in frames)
    on disk, one entry per line, as 'allowed_durs.txt' and
    'allowed_lengths.txt' in args.dir.

    Each candidate duration is converted to a number of frames using
    args.frame_length and args.frame_shift (milliseconds), rounded down to a
    multiple of args.frame_subsampling_factor, and converted back to a
    duration that is half a frame shift longer than that many frames.

    Parameters
    ----------
    start_dur: float
        First candidate duration, in seconds.
    end_dur: float
        Candidates are generated while they are smaller than this.
    args: namespace
        Must provide dir, factor (multiplicative, e.g. 1.12 for 12%),
        frame_shift, frame_length and frame_subsampling_factor.

    Returns
    -------
    allowed_durations: list of allowed durations (in seconds)

    Raises
    ------
    ValueError
        If args.factor is not greater than 1, since the candidates would
        then never reach end_dur.
    """
    if args.factor <= 1.0:
        raise ValueError("args.factor must be greater than 1.0 here, got {}"
                         "".format(args.factor))

    subsampling = args.frame_subsampling_factor
    allowed_durations = []
    last_length = 0
    d = start_dur
    with open(os.path.join(args.dir, 'allowed_durs.txt'), 'w', encoding='latin-1') as durs_fp, \
            open(os.path.join(args.dir, 'allowed_lengths.txt'), 'w', encoding='latin-1') as lengths_fp:
        while d < end_dur:
            num_frames = int(d * 1000 - args.frame_length) // args.frame_shift + 1
            # Round down to a multiple of the subsampling factor.
            length = num_frames - num_frames % subsampling
            if length <= last_length:
                # Rounding cancelled the whole step (or gave a non-positive
                # length). Snapping d back down would repeat this length
                # forever, so keep growing the unsnapped duration instead.
                d *= args.factor
                continue
            last_length = length
            d = (args.frame_shift * (length - 1.0)
                 + args.frame_length + args.frame_shift / 2) / 1000.0
            allowed_durations.append(d)
            durs_fp.write("{}\n".format(d))
            lengths_fp.write("{}\n".format(length))
            d *= args.factor
    return allowed_durations
163+
164+
165+
def get_trivial_allowed_durations(utt2dur, args):
    """Allow every length seen in utt2dur, rounded down to a multiple of
    args.frame_subsampling_factor.

    Each duration is converted to a number of frames using args.frame_length
    and args.frame_shift (milliseconds) and rounded down to a multiple of the
    subsampling factor. The distinct positive lengths are written in
    increasing order, one per line, to 'allowed_lengths.txt' in args.dir, and
    the corresponding durations to 'allowed_durs.txt'. The covered range and
    the number of lengths are logged.

    Parameters
    ----------
    utt2dur: dict
        Maps utterance-id to duration in seconds; values must be
        convertible with float().
    args: namespace
        Must provide dir, frame_shift, frame_length and
        frame_subsampling_factor.

    Returns
    -------
    allowed_durations: list of allowed durations (in seconds), increasing

    Raises
    ------
    ValueError
        If no utterance yields a positive length.
    """
    subsampling = args.frame_subsampling_factor
    # De-duplicate AFTER rounding to the subsampling multiple, with integer
    # arithmetic, so every length is written exactly once and as an integer.
    unique_lengths = set()
    for dur in utt2dur.values():
        num_frames = int(float(dur) * 1000 - args.frame_length) // args.frame_shift + 1
        length = num_frames - num_frames % subsampling
        if length > 0:  # a length of zero frames cannot be used
            unique_lengths.add(length)

    if not unique_lengths:
        raise ValueError("No utterance in utt2dur yields a positive allowed length")

    allowed_durations = []
    with open(os.path.join(args.dir, 'allowed_durs.txt'), 'w', encoding='latin-1') as durs_fp, \
            open(os.path.join(args.dir, 'allowed_lengths.txt'), 'w', encoding='latin-1') as lengths_fp:
        for length in sorted(unique_lengths):
            d = (args.frame_shift * (length - 1.0)
                 + args.frame_length + args.frame_shift / 2) / 1000.0
            allowed_durations.append(d)
            durs_fp.write("{}\n".format(d))
            lengths_fp.write("{}\n".format(length))

    start_dur = allowed_durations[0]
    end_dur = allowed_durations[-1]

    logger.info("Durations in the range [{},{}] will be covered."
                "".format(start_dur, end_dur))
    logger.info("There will be {} unique allowed lengths "
                "for the utterances.".format(len(allowed_durations)))

    return allowed_durations
195+
196+
197+
def main():
    """Entry point: write allowed_durs.txt and allowed_lengths.txt.

    Reads '<data_dir>/utt2dur'. With a spacing factor of 0, every seen
    length (rounded to the subsampling factor) is allowed. Otherwise the
    percentage is turned into a multiplicative factor, the duration range is
    trimmed by the coverage factor, and allowed durations are generated
    geometrically across that range.

    Raises
    ------
    ValueError
        If the spacing factor is negative.
    """
    args = get_args()
    utt2dur = read_kaldi_mapfile(os.path.join(args.data_dir, 'utt2dur'))

    if args.factor == 0.0:
        get_trivial_allowed_durations(utt2dur, args)
        return

    if args.factor < 0.0:
        # A factor below 1.0 after conversion would shrink the duration at
        # every step, so the generation loop could never reach the end.
        raise ValueError("factor must be >= 0, got {}".format(args.factor))

    # From here on args.factor is multiplicative, e.g. 12 (%) -> 1.12.
    args.factor = 1.0 + args.factor / 100.0

    start_dur, end_dur = find_duration_range(utt2dur, args.coverage_factor)
    logger.info("Durations in the range [{},{}] will be covered. "
                "Coverage rate: {}%".format(start_dur, end_dur,
                                            100.0 - args.coverage_factor * 2))

    allowed_durations = get_allowed_durations(start_dur, end_dur, args)
    # Report the number actually written. The log-ratio formula used before
    # ignores rounding to frames and to the subsampling factor, so it could
    # disagree with the files.
    logger.info("There will be {} unique allowed lengths "
                "for the utterances.".format(len(allowed_durations)))


if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)