Skip to content

Commit 28ba6ac

Browse files
committed
Remove w2v file; Add corresponding import
1 parent f22514b commit 28ba6ac

File tree

2 files changed

+4
-5
lines changed

2 files changed

+4
-5
lines changed

app/Dockerfile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ COPY word_freqs .
1616
COPY w2v .
1717
RUN pip3 install -r requirements.txt
1818
RUN python -m nltk.downloader wordnet
19+
RUN python -m nltk.downloader word2vec_sample
1920
RUN python -m nltk.downloader brown
2021
RUN python -m nltk.downloader stopwords
2122
RUN python -m nltk.downloader punkt

app/evaluation.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88
import numpy.linalg
99
from nltk.corpus import stopwords
1010
from nltk import word_tokenize
11+
from nltk.data import find
1112

13+
word2vec_sample = str(find('models/word2vec_sample/pruned.word2vec.txt'))
14+
w2v = gensim.models.KeyedVectors.load_word2vec_format(word2vec_sample, binary=False)
1215

1316
def evaluation_function(response, answer, params):
1417
"""
@@ -155,8 +158,6 @@ def sentence_similarity(response: str, answer: str):
155158
blen = pickle.load(fp)
156159
with open('word_freqs', 'rb') as fp:
157160
freqs = pickle.load(fp)
158-
with open('w2v', 'rb') as fp:
159-
w2v = pickle.load(fp)
160161

161162
def sencence_scores(common_words, sentence):
162163
scores = []
@@ -194,8 +195,6 @@ def preprocess_tokens(text: str):
194195

195196

196197
def sentence_similarity_mean_w2v(response: str, answer: str):
197-
with open('w2v', 'rb') as fp:
198-
w2v = pickle.load(fp)
199198
response = preprocess_tokens(response)
200199
answer = preprocess_tokens(answer)
201200
response_embeddings = [w2v[word] for word in response if w2v.has_index_for(word)]
@@ -206,7 +205,6 @@ def sentence_similarity_mean_w2v(response: str, answer: str):
206205
answer_vector = np.mean(answer_embeddings, axis=0)
207206
return float(
208207
np.dot(response_vector, answer_vector) / (np.linalg.norm(response_vector) * np.linalg.norm(answer_vector)))
209-
# TODO
210208

211209

212210
if __name__ == "__main__":

0 commit comments

Comments
 (0)