Skip to content

Commit b0a9bf3

Browse files
committed
correct current minor bug on evaluation_test function
1 parent ead6366 commit b0a9bf3

File tree

2 files changed

+23
-18
lines changed

2 files changed

+23
-18
lines changed

app/evaluation.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, mode='gpt', llama_version='3_1_8B', temperature=0.01, max_new
1919
self.huggingfacehub_api_token = os.getenv("HUGGINGFACE_AUTHORIZATION")
2020
self.endpoint_3_1_8B = os.getenv("LLAMA3_1_8B_ENDPOINT")
2121

22-
self.response_num_required = 0 #initialise it with 0
22+
#self.response_num_required = 0 #initialise it with 0
2323

2424

2525
def parse_input(input_data):
@@ -78,10 +78,10 @@ def recursive_evaluation(responses, answers, chain, parser):
7878
for ans in list(remaining_answers): # Convert set to list for iteration
7979
eval_result = chain.invoke({"word": res, "target": ans})
8080
eval_result_content = eval_result.content
81-
print("eval_result_content: ", eval_result_content) #TODO: debugging
81+
#print("eval_result_content: ", eval_result_content) #TODO: debugging
8282
similarity_result = parser.invoke(eval_result_content)
8383

84-
print("similarity_result: ", similarity_result, "; res: ", res, "; ans: ", ans) #TODO: debugging
84+
#print("similarity_result: ", similarity_result, "; res: ", res, "; ans: ", ans) #TODO: debugging
8585

8686
if similarity_result == "True":
8787
matched_word = ans
@@ -102,13 +102,11 @@ def evaluation_function(response, answer, param=None):
102102
"""Evaluates the given response against the answer using LLaMA 3 or GPT-4o."""
103103
start_time = time.process_time()
104104

105-
106-
107105
#split the response and answer into lists with semicolons
108106
response = parse_input(response)
109107
answer = parse_input(answer)
110108

111-
print("response: ", response, "; answer: ", answer, "; param: ", param) #TODO: debugging
109+
# print("response: ", response, "; answer: ", answer, "; param: ", param) #TODO: debugging
112110

113111

114112
# Ensure config is provided
@@ -185,14 +183,14 @@ def evaluation_function(response, answer, param=None):
185183
if not (isinstance(response, list) and all(isinstance(item, str) for item in response) and
186184
isinstance(answer, list) and all(isinstance(item, str) for item in answer)):
187185
return {"is_correct": False, "error": "Invalid input: response and answer must be lists of strings."}
188-
print("Valid Inputs received: response: ", response, "; answer: ", answer) #TODO: debugging
186+
# print("Valid Inputs received: response: ", response, "; answer: ", answer) #TODO: debugging
189187

190-
print("Starting recursive evaluation...") #TODO: debugging
188+
# print("Starting recursive evaluation...") #TODO: debugging
191189
is_correct, correct_answers, incorrect_answers = recursive_evaluation(response, answer, chain, parser)
192-
print("correct_answers: ", correct_answers, "; incorrect_answers: ", incorrect_answers) #TODO: debugging
190+
# print("correct_answers: ", correct_answers, "; incorrect_answers: ", incorrect_answers) #TODO: debugging
193191

194192
#check if student is inputting enough answers
195-
if len(response) < param.response_num_required:
193+
if len(response) < len(answer):
196194
is_correct = False
197195

198196
return {
@@ -210,8 +208,10 @@ def evaluation_function(response, answer, param=None):
210208
if __name__ == "__main__":
211209
custom_config = Param()
212210
print(evaluation_function(
213-
"speed,red", #response
214-
"red, velocity", #answer
211+
"Velocity",
212+
"Speed",
213+
# "speed,red", #response
214+
# "red, velocity", #answer
215215
custom_config
216216
))
217217

app/evaluation_tests.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,7 @@ def test_partial_match(self):
4343
"""Test that a response that is too short is marked incorrect."""
4444
response = "Density;Velocity;Viscosity"
4545
answer = "Density;Velocity;Viscosity;Length"
46-
47-
self.param.response_num_required = 4
4846
result = evaluation_function(response, answer, self.param)
49-
50-
self.param.response_num_required = 0
51-
5247
self.assertFalse(result.get("is_correct"))
5348

5449
def test_synonyms_match(self):
@@ -62,7 +57,7 @@ def test_synonyms_match(self):
6257

6358
def test_exact_match_requirement(self):
6459
"""Test enforcing exact match on keystrings."""
65-
response = "density;speed;viscosity;length"
60+
response = "density;velocity;viscosity;length"
6661
answer = "Density;Velocity;Viscosity;Length"
6762

6863
result = evaluation_function(response, answer, self.param)
@@ -86,6 +81,16 @@ def test_negation_handling(self):
8681
result = evaluation_function(response, answer, self.param)
8782

8883

84+
self.assertFalse(result.get("is_correct"))
85+
86+
def test_short_response(self):
87+
"""Test that a response with fewer items than the answer is marked incorrect."""
88+
response = "yellow"
89+
answer = "yellow,blue"
90+
91+
result = evaluation_function(response, answer, self.param)
92+
93+
8994
self.assertFalse(result.get("is_correct"))
9095

9196
def test_performance(self):

0 commit comments

Comments
 (0)