-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathweak_summarizers.py
More file actions
320 lines (241 loc) · 11.3 KB
/
weak_summarizers.py
File metadata and controls
320 lines (241 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
'''
__version__="1.0"
__description__ = "Script defining summarizer classes for weak learners"
__copyright__= "© 2022 MASSACHUSETTS INSTITUTE OF TECHNOLOGY"
__disclaimer__="THE SOFTWARE/FIRMWARE IS PROVIDED TO YOU ON AN “AS-IS” BASIS."
__SPDX_License_Identifier__="BSD-2-Clause"
'''
#!/usr/bin/env python
#
# Imports
#
import numpy as np
from scipy.spatial.distance import cosine
#text rank
from summa.pagerank_weighted import pagerank_weighted_scipy as pagerank
from summa.commons import build_graph as summa_build_graph
from summa.commons import remove_unreachable_nodes as summa_remove_unreachable_nodes
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import pdist, squareform
import scipy
from statistics import median
import re
import random
# Define the label mappings for convenience
SUMM = 1
NOT_SUMM = 0
ABSTAIN = -1
#
# Helpers
#
#modify code here
#
# Summarizer Classes
#
class RandomSummarizer():
    """Baseline weak learner that selects summary sentences uniformly at random.

    Serves as a floor for comparison: any informative heuristic should
    produce better labels than this one.
    """
    def __init__(self, summ_len=3):
        # Maximum number of sentences to include in a summary.
        self.summ_len = summ_len

    def summarize(self, sentences):
        """Pick up to ``summ_len`` sentence indices at random (without replacement)."""
        chosen = random.sample(list(range(len(sentences))), min(len(sentences), self.summ_len))
        picked = [sentences[idx] for idx in chosen]
        return (picked, chosen)

    def get_labels(self, sentences):
        """ return sentence-level labels for sentences predicted to be summary-sentences """
        _, chosen = self.summarize(sentences)
        chosen_set = set(chosen)
        return {i: (SUMM if i in chosen_set else NOT_SUMM) for i in range(len(sentences))}
class BinaryNumbersSummarizer():
    """Weak learner that flags every sentence containing at least one digit."""
    def __init__(self):
        pass

    def find_numbers(self, input_string):
        # True iff the string contains any decimal digit.
        # (adapted from https://stackoverflow.com/questions/19859282/check-if-a-string-contains-a-number)
        return bool(re.search(r'\d', input_string))

    def summarize(self, sentences):
        """Return (digit-bearing sentences, their indices), in document order."""
        keep = []
        for idx, sentence in enumerate(sentences):
            if self.find_numbers(sentence):
                keep.append(idx)
        return ([sentences[idx] for idx in keep], keep)

    def get_labels(self, sentences):
        """ return sentence-level labels for sentences predicted to be summary-sentences """
        keep = set(self.summarize(sentences)[1])
        return {i: (SUMM if i in keep else NOT_SUMM) for i in range(len(sentences))}
class NumbersSummarizer():
    """Weak learner that ranks sentences by digit-character count and keeps the top ones.

    Ties are broken in favor of later sentences (stable sort + tail slice).
    """
    def __init__(self, summary_length=3):
        # Number of top-scoring sentences to keep.
        self.summary_length = summary_length

    def summarize(self, sentences):
        """Return (top sentences by digit count, their indices).

        Indices are returned in ascending-score order (lowest-scoring kept
        sentence first), matching the original slice-based selection.
        Fix: returns ([], []) for empty input instead of raising ValueError
        from max() on an empty sequence.
        """
        if not sentences:
            return ([], [])
        numbers_counts = [len(re.findall(r'\d', sent)) for sent in sentences]
        # Normalizing counts by the maximum does not change their ordering,
        # so rank on raw counts directly; sorted() is stable, so equal counts
        # keep document order and the tail slice favors later sentences.
        ranked = sorted(enumerate(numbers_counts), key=lambda tup: tup[1])
        summary_inds = [tup[0] for tup in ranked[-self.summary_length:]]
        return ([sentences[i] for i in summary_inds], summary_inds)

    def get_labels(self, sentences):
        """ return sentence-level labels for sentences predicted to be summary-sentences """
        summ_indices = self.summarize(sentences)[1]
        sent_labels = {}
        for i in range(len(sentences)):
            sent_labels[i] = SUMM if i in summ_indices else NOT_SUMM
        return sent_labels
class BinarySentenceLengthSummarizer():
    """Weak learner that keeps sentences of 'typical' length.

    Computes the median word count of the shorter half (q1) and of the longer
    half (q3) of the document's sentences, then keeps sentences whose word
    count lies in [q1, q3] — i.e. discards unusually short and unusually long
    sentences.
    """
    def __init__(self):
        pass

    def summarize(self, sentences):
        """Return (mid-length sentences, their indices) in document order.

        Fix: with fewer than two sentences the half-splits are empty and
        median([]) raises StatisticsError, so in that case every sentence
        is kept. Behavior for two or more sentences is unchanged.
        """
        if len(sentences) < 2:
            return (list(sentences), list(range(len(sentences))))
        lengths = [len(sent.split()) for sent in sentences]
        lengths_sorted = sorted(lengths)
        half = int(len(sentences) / 2.0)
        q1 = median(lengths_sorted[:half])   # median of the shorter half
        q3 = median(lengths_sorted[-half:])  # median of the longer half
        summary_inds = [i for i in range(len(sentences)) if q1 <= lengths[i] <= q3]
        return ([sentences[i] for i in summary_inds], summary_inds)

    def get_labels(self, sentences):
        """ return sentence-level labels for sentences predicted to be summary-sentences """
        _, summ_indices = self.summarize(sentences)
        sent_labels = {}
        for i in range(len(sentences)):
            sent_labels[i] = SUMM if i in summ_indices else NOT_SUMM
        return sent_labels
class KalimatDefaultSummarizer():
    # Oracle-style labeler (Kalimat dataset): aligns each reference highlight
    # with the document sentence that scores highest against it under ROUGE.
    def __init__(self):
        return
    def get_labels(self, sentences, highlights):
        """ return sentence-level labels for sentences predicted to be summary-sentences """
        # NOTE(review): `evaluate_rouge` is not defined or imported anywhere in
        # this file — calling this method raises NameError unless the symbol is
        # injected elsewhere; confirm the intended import.
        sentences = [sent.strip() for sent in sentences]
        # NOTE(review): highlight_set is built but never used — the exact-match
        # strategy it supported is the commented-out line below.
        highlight_set = set([h.strip() for h in highlights])
        #todo: closest pyrouge distance?
        summ_indices = set()
        for j in range(len(highlights)):
            # For each highlight, select the single sentence index with the
            # maximum ROUGE score against that highlight.
            summ_indices.add(max([(i,evaluate_rouge([[sentences[i]]], [[highlights[j]]], rouge_args=[])) for i in range(len(sentences))], key=lambda tup:tup[1])[0])
        # summ_indices = [i for i, e in enumerate(sentences) if e in highlight_set]
        sent_labels = {}
        # Emit a SUMM/NOT_SUMM label for every sentence in the document.
        for i in range(len(sentences)):
            if i in summ_indices:
                sent_labels[i] = SUMM
            else:
                sent_labels[i] = NOT_SUMM
        return sent_labels
class CentroidSentenceBertSummarizer():
    # NOTE(review): this class is incomplete — the summarization logic has been
    # redacted/stubbed out (the "#modify code here" markers). `get_labels`
    # calls self.summarize(...), which is not defined in this class, so the
    # class is unusable until the missing implementation is restored.
    def placeholder(): #modify code here
        #modify code here
        # `sentences_summary` is undefined in this stub; calling this method
        # raises NameError as written.
        summary_indices = [s[0] for s in sentences_summary]
        #modify code here
    def get_labels(self, sentences, sentence_embeddings, limit_type='word', limit=100):
        """ return sentence-level labels for sentences predicted to be summary-sentences """
        # Relies on a `summarize` method that is missing from this stub;
        # note the caller-supplied limit_type/limit are shadowed by the
        # hard-coded keyword arguments here.
        _,summ_indices = self.summarize(sentences,sentence_embeddings,limit_type='word',limit=100)
        sent_labels = {}
        for i in range(len(sentences)):
            if i in summ_indices:
                sent_labels[i] = SUMM
            else:
                sent_labels[i] = NOT_SUMM
        return sent_labels
class TextRankSentenceBertSummarizer():
    ''' based on CentroidSentenceVertSummarizer
    also on TextRank package's summarize() function
    Perform TextRank over all sentences in document

    Builds a graph whose nodes are sentence indices and whose edge weights are
    rescaled cosine similarities between sentence embeddings, ranks sentences
    with PageRank, and keeps the top `ratio` fraction of them.
    '''
    def __init__(self, ratio=0.2):
        # Fraction of the document's sentences to keep as summary.
        self.ratio = ratio

    def _get_similarity(self, sentence_1, sentence_2):
        ''' Get similarity bw sentence representations.

        Cosine similarity rescaled from [-1, 1] into [0, 1]; returns 0.0 when
        either vector is all zeros (cosine distance would be undefined there).
        '''
        v1 = self.sentence_embeddings[int(sentence_1)]
        v2 = self.sentence_embeddings[int(sentence_2)]
        similarity = 0.0
        if np.count_nonzero(v1) != 0 and np.count_nonzero(v2) != 0:
            similarity = ((1 - cosine(v1, v2)) + 1) / 2
        return similarity

    def _create_valid_graph(self, graph):
        # Fallback: connect every distinct node pair with weight 1 so PageRank
        # can still run when all pairwise similarities were zero.
        nodes = graph.nodes()
        for i in range(len(nodes)):
            for j in range(len(nodes)):
                if i == j:
                    continue
                edge = (nodes[i], nodes[j])
                if graph.has_edge(edge):
                    graph.del_edge(edge)
                graph.add_edge(edge, 1)

    def _set_graph_edge_weights(self, graph):
        ''' Make graph edge weights based on sentence embedding features '''
        for sentence_1 in graph.nodes():
            for sentence_2 in graph.nodes():
                edge = (sentence_1, sentence_2)
                if sentence_1 != sentence_2 and not graph.has_edge(edge):
                    similarity = self._get_similarity(sentence_1, sentence_2)
                    if similarity != 0:
                        graph.add_edge(edge, similarity)
        # Handles the case in which all similarities are zero.
        # The resultant summary will consist of random sentences.
        if all(graph.edge_weight(edge) == 0 for edge in graph.edges()):
            self._create_valid_graph(graph)

    def compute_pagerank(self, sentence_embeddings):
        """ Compute pagerank scores over each sentence embedding in document.

        Stores the result in self.pagerank_scores as {sentence index -> score},
        ordered by descending score. Unreachable nodes are dropped first.
        """
        self.sentence_embeddings = sentence_embeddings
        # Fix: always initialize pagerank_scores, so summarize()/get_labels()
        # don't raise AttributeError when the graph is empty and we return early.
        self.pagerank_scores = {}
        graph = summa_build_graph([i for i in range(sentence_embeddings.shape[0])])
        self._set_graph_edge_weights(graph)
        # Remove all nodes with all edges weights equal to zero.
        summa_remove_unreachable_nodes(graph)
        # PageRank cannot be run in an empty graph.
        if len(graph.nodes()) == 0:
            return
        # Ranks the tokens using the PageRank algorithm. Returns dict of sentence ind -> pagerank score
        pagerank_scores = pagerank(graph)
        pagerank_scores = {int(k): v for k, v in sorted(pagerank_scores.items(), key=lambda item: item[1], reverse=True)}
        self.pagerank_scores = pagerank_scores

    def summarize(self, sentences, sentence_embeddings):
        """ Derive extractive summaries from sentences.

        Returns (summary_sentences, summary_inds): the top
        int(len(sentences) * ratio) sentence indices by PageRank score.
        """
        self.compute_pagerank(sentence_embeddings)
        sorted_pr_inds = list(self.pagerank_scores.keys())
        # Extract the most important sentences with the selected criterion.
        num_sent_to_extract = int(len(sentences) * self.ratio)
        summary_inds = sorted_pr_inds[:num_sent_to_extract]
        summary_sentences = [sentences[i] for i in summary_inds]
        return summary_sentences, summary_inds

    def get_labels(self, sentence_embeddings):
        """ return sentence-level labels for sentences predicted to be summary-sentences """
        self.compute_pagerank(sentence_embeddings)
        num_to_extract = sentence_embeddings.shape[0] * self.ratio
        # pagerank_scores is ordered by descending score, so the first
        # num_to_extract entries are the summary sentences.
        sorted_pr = {k: SUMM if i < num_to_extract else NOT_SUMM for i, (k, v) in enumerate(self.pagerank_scores.items())}
        # add back in NOT_SUMM for nodes that were removed as unreachable
        sorted_pr.update({k: NOT_SUMM for k in range(len(sentence_embeddings)) if k not in sorted_pr})
        return {k: v for k, v in sorted(sorted_pr.items(), key=lambda item: item[0])}
class SentenceIndexSummarizer():
    ''' A basic heuristic summarizer based on target indices (i.e. first sentence, first-three sentences, etc.)
    '''
    def __init__(self, indices):
        # Accept a bare int for convenience; normalize to a list.
        if isinstance(indices, int):
            indices = [indices]
        self.indices = indices

    def summarize(self, sentences):
        """ Derive extractive summaries from sentences.

        Returns (summary_sentences, indices); target indices beyond the end
        of the document are silently dropped.
        """
        indices = [i for i in self.indices if i < len(sentences)]
        return [sentences[i] for i in indices], indices

    def get_labels(self, sentences):
        """ return sentence-level labels for sentences predicted to be summary-sentences.

        Target-index sentences get SUMM; all others ABSTAIN, since this
        heuristic carries no evidence about non-target sentences.
        """
        # Fix: label every configured target index, not just index 0, so the
        # labels agree with summarize() (the old code hardcoded i == 0 and
        # ignored self.indices; identical for the default-style [0] case).
        target = set(self.indices)
        labels = {}
        for i in range(len(sentences)):
            labels[i] = SUMM if i in target else ABSTAIN
        return labels