Skip to content

Commit fbe36e7

Browse files
committed
the minimum length of answer is set in config which is a global class, need to set it back to 0 if the same config is being called elsewhere: self.config.response_num_required = 4
result = evaluation_function(response, answer, self.config) self.config.response_num_required = 0
1 parent 311b07d commit fbe36e7

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

app/evaluation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def __init__(self, mode='gpt', llama_version='3_1_8B', temperature=0.01, max_new
1919
self.huggingfacehub_api_token = os.getenv("HUGGINGFACE_AUTHORIZATION")
2020
self.endpoint_3_1_8B = os.getenv("LLAMA3_1_8B_ENDPOINT")
2121

22-
self.response_num_required = 3
22+
self.response_num_required = 0 #initialise it with 0
2323

2424
def setup_llm(config):
2525
"""Initialize the LLM model (GPT-4o or LLaMA 3) based on the given configuration."""
@@ -145,6 +145,7 @@ def evaluation_function(response, answer, config=None):
145145
#check if student is inputting enough answers
146146
if len(response) < config.response_num_required:
147147
is_correct = False
148+
148149
return {
149150
"is_correct": is_correct,
150151
"result": {
@@ -160,8 +161,8 @@ def evaluation_function(response, answer, config=None):
160161
if __name__ == "__main__":
161162
custom_config = Config()
162163
print(evaluation_function(
163-
["Density","Density","Density"], #response
164-
["Density","Viscosity","Length","Density","Gravity","Viscosity","Length"], #answer
164+
["speed"], #response
165+
["velocity"], #answer
165166
custom_config
166167
))
167168

app/evaluation_tests.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,14 @@ def test_partial_match(self):
4242

4343
self.config.response_num_required = 4
4444
result = evaluation_function(response, answer, self.config)
45+
self.config.response_num_required = 0
4546
self.assertFalse(result.get("is_correct"))
4647

4748

4849
def test_synonyms_match(self):
4950
"""Test if abbriviations are correctly identified."""
50-
response = ['speed']
51-
answer = ['velocity']
51+
response = ['velocity']
52+
answer = ['speed']
5253
result = evaluation_function(response, answer, self.config)
5354

5455
self.assertTrue(result.get("is_correct"))

0 commit comments

Comments
 (0)