
Commit bf99b45

General formatting updates; no major changes from the version discussed with Alex
1 parent 619a388

3 files changed (+108, -31 lines)

app/compare_text_lists.py

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+def process_list(input_list):
+    """
+    Detects if the input is a list, and if any element in the list contains semicolons,
+    it splits that element into multiple elements.
+
+    Args:
+        input_list (list): A list of strings.
+
+    Returns:
+        list: A processed list where semicolon-separated elements are split into separate elements.
+    """
+    if not isinstance(input_list, list):
+        raise ValueError("Input must be a list of strings.")
+
+    processed_list = []
+    for item in input_list:
+        if not isinstance(item, str):
+            raise ValueError("All elements in the input list must be strings.")
+
+        # Split by semicolon if present, otherwise keep the original item
+        processed_list.extend(item.split(';') if ';' in item else [item])
+
+    return processed_list
+def test_process_list():
+    """
+    Unit tests for process_list function.
+    """
+    test_cases = [
+        (["apple", "banana;orange", "grape"], ["apple", "banana", "orange", "grape"]),
+        (["one;two;three", "four", "five"], ["one", "two", "three", "four", "five"]),
+        (["alpha;beta", "gamma;delta;epsilon"], ["alpha", "beta", "gamma", "delta", "epsilon"]),
+        (["no_separator"], ["no_separator"]),
+        ([], []),
+        (["single"], ["single"]),
+    ]
+
+    for i, (input_list, expected_output) in enumerate(test_cases):
+        assert process_list(input_list) == expected_output, f"Test case {i+1} failed"
+
+    print("All test cases passed!")
+
+# Run the tests
+test_process_list()
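Two edge cases of process_list worth flagging: str.split(';') keeps surrounding whitespace and yields empty strings for doubled semicolons, and the module calls test_process_list() at import time. A quick sanity check (assuming the file is importable as compare_text_lists):

    from compare_text_lists import process_list  # note: importing also runs the tests

    print(process_list(["banana; orange"]))  # -> ['banana', ' orange']  (leading space kept)
    print(process_list(["a;;b"]))            # -> ['a', '', 'b']  (empty element preserved)

Whether those behaviours are desirable for grading input is worth confirming.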

app/evaluation.py

Lines changed: 52 additions & 18 deletions
@@ -7,7 +7,7 @@
 from langchain.prompts import PromptTemplate
 from dotenv import load_dotenv

-class Config:
+class Param:
     def __init__(self, mode='gpt', llama_version='3_1_8B', temperature=0.01, max_new_token=5):
         load_dotenv()

@@ -21,22 +21,46 @@ def __init__(self, mode='gpt', llama_version='3_1_8B', temperature=0.01, max_new

     self.response_num_required = 0 #initialise it with 0

-def setup_llm(config):
+def compareTextLists(input_list):
+    """
+    Detects if the input is a list, and if any element in the list contains semicolons,
+    it splits that element into multiple elements.
+
+    Args:
+        input_list (list): A list of strings.
+
+    Returns:
+        list: A processed list where semicolon-separated elements are split into separate elements.
+    """
+    if not isinstance(input_list, list):
+        raise ValueError("Input must be a list of strings.")
+
+    processed_list = []
+    for item in input_list:
+        if not isinstance(item, str):
+            raise ValueError("All elements in the input list must be strings.")
+
+        # Split by semicolon if present, otherwise keep the original item
+        processed_list.extend(item.split(';') if ';' in item else [item])
+
+    return processed_list
+
+def setup_llm(param):
     """Initialize the LLM model (GPT-4o or LLaMA 3) based on the given configuration."""
-    if config.mode == 'gpt':
+    if param.mode == 'gpt':
         return ChatOpenAI(
             model="gpt-4o-mini",
-            temperature=config.temperature,
-            max_tokens=config.max_new_token,
-            openai_api_key=config.openai_api_key
+            temperature=param.temperature,
+            max_tokens=param.max_new_token,
+            openai_api_key=param.openai_api_key
         )
-    elif config.mode == 'llama3':
+    elif param.mode == 'llama3':
         from langchain_huggingface import HuggingFaceEndpoint
         return HuggingFaceEndpoint(
-            endpoint_url=config.endpoint_3_1_8B,
-            max_new_tokens=config.max_new_token,
-            temperature=config.temperature,
-            huggingfacehub_api_token=config.huggingfacehub_api_token
+            endpoint_url=param.endpoint_3_1_8B,
+            max_new_tokens=param.max_new_token,
+            temperature=param.temperature,
+            huggingfacehub_api_token=param.huggingfacehub_api_token
         )
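As a reference point, setup_llm dispatches purely on param.mode, so a caller selects the backend roughly like this (a minimal sketch; the 'gpt' path needs an OpenAI key and the 'llama3' path a Hugging Face token, both picked up from the .env file Param loads):

    from evaluation import Param, setup_llm

    llm = setup_llm(Param(mode='gpt'))       # ChatOpenAI over gpt-4o-mini
    # llm = setup_llm(Param(mode='llama3'))  # HuggingFaceEndpoint for LLaMA 3.1 8B

Any other mode falls through both branches and setup_llm returns None, which callers may want to guard against.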
@@ -74,16 +98,26 @@ def recursive_evaluation(responses, answers, chain, parser):

     return all(results), matched_pairs, unmatched_responses

-def evaluation_function(response, answer, config=None):
+def evaluation_function(response, answer, param=None):
     """Evaluates the given response against the answer using LLaMA 3 or GPT-4o."""
+
+
+
+
+    #split the response and answer into lists with semicolons
+    response = compareTextLists(response)
+
+
+
+
     start_time = time.process_time()

-    # Ensure config is provided
-    if config is None:
-        config = Config()
+    # Ensure param is provided
+    if param is None:
+        param = Param()

     # Initialize LLM
-    llm = setup_llm(config)
+    llm = setup_llm(param)

     # Define prompt template
     prompt_template = PromptTemplate(
@@ -143,7 +177,7 @@ def evaluation_function(response, answer, config=None):

     is_correct, correct_answers, incorrect_answers = recursive_evaluation(response, answer, chain, parser)
     #check if student is inputting enough answers
-    if len(response) < config.response_num_required:
+    if len(response) < param.response_num_required:
         is_correct = False

     return {
@@ -159,7 +193,7 @@ def evaluation_function(response, answer, config=None):

 # Example Usage
 if __name__ == "__main__":
-    custom_config = Config()
+    custom_config = Param()
     print(evaluation_function(
         ["speed"], #response
         ["velocity"], #answer

app/evaluation_tests.py

Lines changed: 13 additions & 13 deletions
@@ -1,5 +1,5 @@
 import unittest
-from evaluation import evaluation_function, Config
+from evaluation import evaluation_function, Param

 class TestEvaluationFunction(unittest.TestCase):
     """
@@ -16,22 +16,22 @@ class TestEvaluationFunction(unittest.TestCase):

     @classmethod
     def setUpClass(cls):
-        """Initialize a shared Config instance for LLM setup."""
-        cls.config = Config()
+        """Initialize a shared Param instance for LLM setup."""
+        cls.param = Param()

     def test_basic_correct_response(self):
         """Test if semantically similar responses are marked correct."""
         response = ["Density", "Velocity", "Viscosity", "Length"]
         answer = ["Density", "Velocity", "Viscosity", "Length"]
-        result = evaluation_function(response, answer, self.config)
+        result = evaluation_function(response, answer, self.param)

         self.assertTrue(result.get("is_correct"))

     def test_basic_incorrect_response(self):
         """Test if semantically different responses are marked incorrect."""
         response = ["Mass", "Speed", "Friction", "Force"]
         answer = ["Density", "Velocity", "Viscosity", "Length"]
-        result = evaluation_function(response, answer, self.config)
+        result = evaluation_function(response, answer, self.param)

         self.assertFalse(result.get("is_correct"))

@@ -40,9 +40,9 @@ def test_partial_match(self):
         response = ["Density", "Velocity", "Viscosity"]
         answer = ["Density", "Velocity", "Viscosity", "Length"]

-        self.config.response_num_required = 4
-        result = evaluation_function(response, answer, self.config)
-        self.config.response_num_required = 0
+        self.param.response_num_required = 4
+        result = evaluation_function(response, answer, self.param)
+        self.param.response_num_required = 0

         self.assertFalse(result.get("is_correct"))

@@ -51,7 +51,7 @@ def test_synonyms_match(self):
         """Test if abbreviations are correctly identified."""
         response = ['velocity']
         answer = ['speed']
-        result = evaluation_function(response, answer, self.config)
+        result = evaluation_function(response, answer, self.param)

         self.assertTrue(result.get("is_correct"))

@@ -60,15 +60,15 @@ def test_exact_match_requirement(self):
         response = ["density", "speed", "viscosity", "length"]
         answer = ["Density", "Velocity", "Viscosity", "Length"]

-        result = evaluation_function(response, answer, self.config)
+        result = evaluation_function(response, answer, self.param)
         self.assertTrue(result.get("is_correct"))

     def test_should_not_contain(self):
         """Test if a response with a prohibited keyword fails."""
         response = ["density", "velocity", "viscosity", "length", "direction"]
         answer = ["Density", "Velocity", "Viscosity", "Length"]

-        result = evaluation_function(response, answer, self.config)
+        result = evaluation_function(response, answer, self.param)
         self.assertFalse(result.get("is_correct"))

@@ -77,7 +77,7 @@ def test_negation_handling(self):
         response = ["not light blue", "dark blue"]
         answer = ["light blue"]

-        result = evaluation_function(response, answer, self.config)
+        result = evaluation_function(response, answer, self.param)

         self.assertFalse(result.get("is_correct"))

@@ -86,7 +86,7 @@ def test_performance(self):
         response = ["Density", "Velocity", "Viscosity", "Length"]
         answer = ["Density", "Velocity", "Viscosity", "Length"]

-        result = evaluation_function(response, answer, self.config)
+        result = evaluation_function(response, answer, self.param)
         processing_time = result.get("result", {}).get("processing_time", 0)

         self.assertLess(processing_time, 5, msg="Evaluation function should run efficiently.")
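A usage note: the suite imports evaluation as a sibling module, so it presumably runs from inside app/ (inferred from the import path). A minimal runner sketch:

    # run_tests.py (hypothetical), placed in app/ next to evaluation_tests.py
    import unittest

    # Runs TestEvaluationFunction; every case makes a live LLM call, so the
    # keys loaded by load_dotenv() must be present in the environment.
    unittest.main(module="evaluation_tests", verbosity=2)

The 5-second bound in test_performance depends on endpoint latency, so it may flake on a slow connection.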
