
Commit ec2c9b3

evaluation function updated to use GPT-4o; a new test method called fitted evaluation test was created; other minor changes include a new environment requirements.txt
1 parent b70bb1f commit ec2c9b3
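
Note: the diff below shows only app/evaluation.py, but the new imports it introduces imply roughly the following entries for the environment requirements file mentioned in the commit message. This is a sketch only: package names are inferred from the imports, and no version pins are taken from this commit.

# requirements.txt — inferred from the new imports; pins omitted on purpose
langchain
langchain-openai
langchain-huggingface  # only needed for the 'llama3' mode
python-dotenv
pandas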

File tree

5 files changed: 251 additions, 380 deletions


app/evaluation.py

Lines changed: 151 additions & 221 deletions
@@ -1,224 +1,154 @@
-import pickle
-import string
 import time
-
-import gensim
-import numpy as np
-from nltk.corpus import stopwords
-from nltk import word_tokenize
-from nltk.data import find
-
-word2vec_sample = str(find('models/word2vec_sample/pruned.word2vec.txt'))
-w2v = gensim.models.KeyedVectors.load_word2vec_format(word2vec_sample, binary=False)
-
-def evaluation_function(response, answer, params):
-    """
-    Function used to evaluate a student response.
-    ---
-    The handler function passes three arguments to evaluation_function():
-
-    - `response` which are the answers provided by the student.
-    - `answer` which are the correct answers to compare against.
-    - `params` which are any extra parameters that may be useful,
-      e.g., error tolerances.
-
-    The output of this function is what is returned as the API response
-    and therefore must be JSON-encodable. It must also conform to the
-    response schema.
-
-    Any standard python library may be used, as well as any package
-    available on pip (provided it is added to requirements.txt).
-
-    The way you wish to structure your code (all in this function, or
-    split into many) is entirely up to you. All that matters are the
-    return types and that evaluation_function() is the main function used
-    to output the evaluation response.
-    """
+import os
+import re
+import pandas as pd
+from langchain.schema.runnable import RunnableLambda
+from langchain_openai import ChatOpenAI
+from langchain.prompts import PromptTemplate
+from dotenv import load_dotenv
+
+class Config:
+    def __init__(self, mode='gpt', llama_version='3_1_8B', temperature=0.01, max_new_token=5):
+        load_dotenv()
+
+        self.mode = mode  # Options: 'gpt', 'llama3'
+        self.llama_version = llama_version
+        self.temperature = temperature
+        self.max_new_token = max_new_token
+        self.openai_api_key = os.getenv('OPENAI_API_KEY')
+        self.huggingfacehub_api_token = os.getenv("HUGGINGFACE_AUTHORIZATION")
+        self.endpoint_3_1_8B = os.getenv("LLAMA3_1_8B_ENDPOINT")
+
+        self.response_num_required = 3
+
+def setup_llm(config):
+    """Initialize the LLM (GPT-4o or LLaMA 3) based on the given configuration."""
+    if config.mode == 'gpt':
+        return ChatOpenAI(
+            model="gpt-4o-mini",
+            temperature=config.temperature,
+            max_tokens=config.max_new_token,
+            openai_api_key=config.openai_api_key
+        )
+    elif config.mode == 'llama3':
+        from langchain_huggingface import HuggingFaceEndpoint
+        return HuggingFaceEndpoint(
+            endpoint_url=config.endpoint_3_1_8B,
+            max_new_tokens=config.max_new_token,
+            temperature=config.temperature,
+            huggingfacehub_api_token=config.huggingfacehub_api_token
+        )
+
+
+def parse_last_boolean(response):
+    """Extract the last boolean value (True/False) from the model response."""
+    matches = re.findall(r'\b(true|false)\b', response, re.IGNORECASE)
+    return matches[-1].capitalize() if matches else "Unsure"
+
+def recursive_evaluation(responses, answers, chain, parser):
+    """Evaluate a list of responses against a list of answers."""
+    results = []
+    matched_pairs = []  # Store matched responses
+    unmatched_responses = []  # Store unmatched responses
+    remaining_answers = set(answers)  # Use a set for faster removal
+
+    for res in responses:
+        matched_word = None
+        for ans in list(remaining_answers):  # Convert set to list for iteration
+            eval_result = chain.invoke({"word": res, "target": ans})
+            eval_result_content = eval_result.content
+            similarity_result = parser.invoke(eval_result_content)
+
+            if similarity_result == "True":
+                matched_word = ans
+                # matched_pairs.append((res, ans))
+                matched_pairs.append(res)  # we don't want the exact answer revealed in the feedback
+                remaining_answers.discard(ans)  # Ensure immediate removal
+                break  # Exit loop after first match
+
+        if matched_word:
+            results.append(True)
+        else:
+            results.append(False)
+            unmatched_responses.append(res)
+
+    return all(results), matched_pairs, unmatched_responses
+
+def evaluation_function(response, answer, config=None):
+    """Evaluate the given response against the answer using LLaMA 3 or GPT-4o."""
     start_time = time.process_time()
-
-
-    # params of the form {'keystrings': ['keystring1', 'keystring2', ...]}
-    # keystring of the form {'string':..., 'exact_match:False', 'should_contain:True', 'custom_feedback:None}
-    if params is not None and "keystrings" in params:
-        keystrings = params["keystrings"]
-        problematic_keystring = None
-        keystring_scores = []
-        response_tokens = preprocess_tokens(response)
-        for keystring_object in keystrings:
-            # Unpack keystring object
-            keystring = keystring_object['string']
-            exact_match = keystring_object['exact_match'] if 'exact_match' in keystring_object else False
-            should_contain = keystring_object['should_contain'] if 'should_contain' in keystring_object else True
-            custom_feedback = keystring_object['custom_feedback'] if 'custom_feedback' in keystring_object else None
-            keystring_tokens = preprocess_tokens(keystring)
-
-            # Sliding window matching
-            window_size = len(keystring_tokens)
-            i = 0
-            max_score = 0
-            while i + window_size <= len(response_tokens):
-                response_substring = " ".join(response_tokens[i:i + window_size])
-                score1 = sentence_similarity_mean_w2v(response_substring, keystring)
-                score2, _, _ = sentence_similarity(response_substring, keystring)
-                max_score = max(score1, score2, max_score)
-                i += 1
-            keystring_scores.append((keystring, max_score))
-
-            threshold = 0.75
-            if exact_match is True:
-                threshold = 0.99
-
-            if should_contain is True and max_score < threshold and problematic_keystring is None:
-                problematic_keystring = keystring
-                feedback = f"Cannot determine if the answer is correct. Please provide more information about '{problematic_keystring}'"
-
-            if should_contain is False and max_score > threshold and problematic_keystring is None:
-                problematic_keystring = keystring
-                feedback = f"Cannot determine if the answer is correct. Identified '{problematic_keystring}' in the answer, which was not expected."
-
-            if custom_feedback is not None:
-                feedback = f"Cannot determine if the answer is correct. {custom_feedback}"
-
-        if problematic_keystring is not None:
-            return {
-                "is_correct": False,
-                "result": {
-                    "response": response,
-                    "processing_time": time.process_time() - start_time,
-                    "keystring-scores": keystring_scores
-                },
-                "feedback": feedback
-            }
-
-    w2v_similarity = sentence_similarity_mean_w2v(response, answer)
-
-    if w2v_similarity > 0.75:
-        return {
-            "is_correct": True,
-            "result": {
-                "response": response,
-                "processing_time": time.process_time() - start_time,
-                "method": "w2v",
-                "similarity_value": w2v_similarity
-            },
-            "feedback": f"Confidence: {'%.3f'%(w2v_similarity)}%"
-        }
-
-    else:
-        similarity, response_scores, answer_scores = sentence_similarity(response, answer)
-        dif = 0
-        word = None
-        for (resp_score, ans_score) in zip(response_scores, answer_scores):
-            if ans_score[0] - resp_score[0] > dif:
-                dif = ans_score[0] - resp_score[0]
-                word = resp_score[1]
-
-        both_one_word = len(response.split(' ')) == 1 and len(answer.split(' ')) == 1
-        more_info_msg = f'Please provide more information about {word}' if word is not None else ''
-        feedback_msg = (
-            "Incorrect" if both_one_word
-            else f"Cannot determine if the answer is correct ({'%.3f'%(w2v_similarity)}% similarity). {more_info_msg}")
-
-        return {
-            "is_correct": False,
-            "result": {
-                "response": response,
-                "processing_time": time.process_time() - start_time,
-                "method": "BOW vector similarity",
-                "similarity_value": w2v_similarity,
-                "BOW_similarity_value": similarity,
-                "problematic_word": word
-            },
-            "feedback": feedback_msg,
-        }
-
-
-def word_information_content(word, blen, freqs):
-    if word not in freqs:
-        f = 0
-    else:
-        f = freqs[word]
-    return 1 - (np.log(f + 1)) / (np.log(blen + 1))
-
-
-def word_similarity(word1, word2, w2v):
-    if word1 == word2:
-        return 1
-    if not w2v.has_index_for(word1) or not w2v.has_index_for(word2):
-        return 0
-    return w2v.similarity(word1, word2)
-
-
-def sentence_similarity(response: str, answer: str):
-    response = response.lower()
-    answer = answer.lower()
-    for punc in string.punctuation:
-        response = response.replace(punc, ' ')
-        answer = answer.replace(punc, ' ')
-    response_words = response.split()
-    answer_words = answer.split()
-    all_words = list(set((response_words + answer_words)))
-
-    with open('brown_length', 'rb') as fp:
-        blen = pickle.load(fp)
-    with open('word_freqs', 'rb') as fp:
-        freqs = pickle.load(fp)
-
-    def sencence_scores(common_words, sentence):
-        scores = []
-        for word in common_words:
-            best_similarity = -1
-            best_word = word
-            for other_word in sentence:
-                similarity = word_similarity(word, other_word, w2v)
-                if similarity > best_similarity:
-                    best_similarity = similarity
-                    best_word = other_word
-            scores.append(
-                (best_similarity * word_information_content(word, blen, freqs) * word_information_content(best_word,
-                                                                                                           blen, freqs),
-                 word))
-        return scores
-
-    response_scores = sencence_scores(all_words, response_words)
-    answer_scores = sencence_scores(all_words, answer_words)
-
-    resp_scores = response_scores.copy()
-    ans_scores = answer_scores.copy()
-    for idx in range(len(response_scores)):
-        response_scores[idx] = response_scores[idx][0]
-        answer_scores[idx] = answer_scores[idx][0]
-    score = np.dot(response_scores, answer_scores) / (np.linalg.norm(response_scores) * np.linalg.norm(answer_scores))
-    return score, resp_scores, ans_scores
-
-
-def preprocess_tokens(text: str):
-    text = text.lower()
-    to_remove = stopwords.words('english') + list(string.punctuation)
-    tokens = [word for word in word_tokenize(text) if word not in to_remove]
-    return tokens
-
-
-def sentence_similarity_mean_w2v(response: str, answer: str):
-    response = preprocess_tokens(response)
-    answer = preprocess_tokens(answer)
-    response_embeddings = [w2v[word] for word in response if w2v.has_index_for(word)]
-    answer_embeddings = [w2v[word] for word in answer if w2v.has_index_for(word)]
-    if len(response_embeddings) == 0 or len(answer_embeddings) == 0:
-        return 0
-    response_vector = np.mean(response_embeddings, axis=0)
-    answer_vector = np.mean(answer_embeddings, axis=0)
-    return float(
-        np.dot(response_vector, answer_vector) / (np.linalg.norm(response_vector) * np.linalg.norm(answer_vector)))
-
-
+
+    # Ensure config is provided
+    if config is None:
+        config = Config()
+
+    # Initialize LLM
+    llm = setup_llm(config)
+
+    # Define prompt template
+    prompt_template = PromptTemplate(
+        template='''
+        ### Instruction:
+        Determine if the 2 words are semantically similar. Provide one of the following responses:
+        - "True" if the words are semantically the same.
+        - "False" if the words are semantically different.
+
+        ### Examples:
+        Word1: "happy", Word2: "happy"
+        Response: True
+
+        Word1: "happy", Word2: "joyful"
+        Response: True
+
+        Word1: "cat", Word2: "dog"
+        Response: False
+
+        Word1: "bank", Word2: "actor"
+        Response: False
+
+        ### Input:
+        Word1: {target}, Word2: {word}
+
+        ### Response:
+        ''',
+        input_variables=["target", "word"]
+    )
+
+    parser = RunnableLambda(parse_last_boolean)
+    chain = prompt_template | llm
+
+    # Validate inputs
+    if not (isinstance(response, list) and all(isinstance(item, str) for item in response) and
+            isinstance(answer, list) and all(isinstance(item, str) for item in answer)):
+        return {"is_correct": False, "error": "Invalid input: response and answer must be lists of strings."}
+
+    is_correct, correct_answers, incorrect_answers = recursive_evaluation(response, answer, chain, parser)
+    # check that the student has provided enough answers
+    if len(response) < config.response_num_required:
+        is_correct = False
+    return {
+        "is_correct": is_correct,
+        "result": {
+            "response": {"correct": correct_answers, "incorrect": incorrect_answers},
+            "processing_time": time.process_time() - start_time,
+            "method": "LLM-based comparison"
+        },
+        "feedback": f"Correct answers: {correct_answers}. Incorrect answers: {incorrect_answers}."
+    }
+
+
+# Example Usage
 if __name__ == "__main__":
-    pass
-    print(evaluation_function("Density, speed, Viscosity, Length", "Density, Velocity, Viscosity, Length", {'keystrings': [{"string": "density"}, {"string": "velocity", "exact_match": False, 'should_contain': False}, {"string": "viscosity"}, {"string": "length"}]}))
-    print(evaluation_function("Molecules are made out of atoms", "Many atoms form a molecule", {'keystrings': [{'string': 'molecule'}, {'string': 'proton', 'exact_match': True}]}))
-
-    # File sizes / Location / Permissions
-    # Clear everything including nltk. Test with small files.
-    #
-    # Confidence score for evaluations of answers, grouped by 'correct'/'incorrect' answers
-    #
-
+    custom_config = Config()
+    print(evaluation_function(
+        ["Density", "Density", "Density"],  # response
+        ["Density", "Viscosity", "Length", "Density", "Gravity", "Viscosity", "Length"],  # answer
+        custom_config
+    ))
+
+    # print(evaluation_function(
+    #     "Molecules are made out of atoms",
+    #     "Many atoms form a molecule",
+    #     {'keystrings': [{'string': 'molecule'}, {'string': 'proton', 'exact_match': True}]},
+    #     custom_config
+    # ))
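
Note on trying the new pipeline offline: evaluation_function needs an OPENAI_API_KEY at call time, but the matching logic in recursive_evaluation can be sanity-checked without one by swapping the LLM chain for a stub. A minimal sketch, assuming the repo layout above makes app/evaluation.py importable and the LangChain dependencies are installed; FakeResult and FakeChain are hypothetical stand-ins, not part of this commit:

from langchain.schema.runnable import RunnableLambda
from app.evaluation import recursive_evaluation, parse_last_boolean

class FakeResult:
    """Mimics a chat-model result: recursive_evaluation reads `.content`."""
    def __init__(self, content):
        self.content = content

class FakeChain:
    """Stub chain: declares two words similar iff they match case-insensitively."""
    def invoke(self, inputs):
        same = inputs["word"].strip().lower() == inputs["target"].strip().lower()
        return FakeResult("Response: True" if same else "Response: False")

parser = RunnableLambda(parse_last_boolean)  # extracts the last True/False token

is_correct, matched, unmatched = recursive_evaluation(
    ["Density", "Speed", "Length"],      # student responses
    ["Density", "Viscosity", "Length"],  # reference answers
    FakeChain(),
    parser,
)
print(is_correct, matched, unmatched)  # -> False ['Density', 'Length'] ['Speed']

One caveat worth noting: recursive_evaluation reads eval_result.content, which matches the message objects returned by ChatOpenAI, whereas HuggingFaceEndpoint returns plain strings, so the 'llama3' mode would need that line adapted (e.g., falling back to the raw result when there is no .content attribute).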
