11import unittest
2- from evaluation import evaluation_function , Config
2+ from evaluation import evaluation_function , Param
33
44class TestEvaluationFunction (unittest .TestCase ):
55 """
@@ -16,22 +16,22 @@ class TestEvaluationFunction(unittest.TestCase):
1616
1717 @classmethod
1818 def setUpClass (cls ):
19- """Initialize a shared Config instance for LLM setup."""
20- cls .config = Config ()
19+ """Initialize a shared Param instance for LLM setup."""
20+ cls .param = Param ()
2121
2222 def test_basic_correct_response (self ):
2323 """Test if semantically similar responses are marked correct."""
2424 response = ["Density" , "Velocity" , "Viscosity" , "Length" ]
2525 answer = ["Density" , "Velocity" , "Viscosity" , "Length" ]
26- result = evaluation_function (response , answer , self .config )
26+ result = evaluation_function (response , answer , self .param )
2727
2828 self .assertTrue (result .get ("is_correct" ))
2929
3030 def test_basic_incorrect_response (self ):
3131 """Test if semantically different responses are marked incorrect."""
3232 response = ["Mass" , "Speed" , "Friction" , "Force" ]
3333 answer = ["Density" , "Velocity" , "Viscosity" , "Length" ]
34- result = evaluation_function (response , answer , self .config )
34+ result = evaluation_function (response , answer , self .param )
3535
3636 self .assertFalse (result .get ("is_correct" ))
3737
@@ -40,9 +40,9 @@ def test_partial_match(self):
4040 response = ["Density" , "Velocity" , "Viscosity" ]
4141 answer = ["Density" , "Velocity" , "Viscosity" , "Length" ]
4242
43- self .config .response_num_required = 4
44- result = evaluation_function (response , answer , self .config )
45- self .config .response_num_required = 0
43+ self .param .response_num_required = 4
44+ result = evaluation_function (response , answer , self .param )
45+ self .param .response_num_required = 0
4646
4747 self .assertFalse (result .get ("is_correct" ))
4848
@@ -51,7 +51,7 @@ def test_synonyms_match(self):
5151 """Test if abbreviations are correctly identified."""
5252 response = ['velocity' ]
5353 answer = ['speed' ]
54- result = evaluation_function (response , answer , self .config )
54+ result = evaluation_function (response , answer , self .param )
5555
5656 self .assertTrue (result .get ("is_correct" ))
5757
@@ -60,15 +60,15 @@ def test_exact_match_requirement(self):
6060 response = ["density" , "speed" , "viscosity" , "length" ]
6161 answer = ["Density" , "Velocity" , "Viscosity" , "Length" ]
6262
63- result = evaluation_function (response , answer , self .config )
63+ result = evaluation_function (response , answer , self .param )
6464 self .assertTrue (result .get ("is_correct" ))
6565
6666 def test_should_not_contain (self ):
6767 """Test if a response with a prohibited keyword fails."""
6868 response = ["density" , "velocity" , "viscosity" , "length" , "direction" ]
6969 answer = ["Density" , "Velocity" , "Viscosity" , "Length" ]
7070
71- result = evaluation_function (response , answer , self .config )
71+ result = evaluation_function (response , answer , self .param )
7272 self .assertFalse (result .get ("is_correct" ))
7373
7474
@@ -77,7 +77,7 @@ def test_negation_handling(self):
7777 response = ["not light blue" , "dark blue" ]
7878 answer = ["light blue" ]
7979
80- result = evaluation_function (response , answer , self .config )
80+ result = evaluation_function (response , answer , self .param )
8181
8282 self .assertFalse (result .get ("is_correct" ))
8383
@@ -86,7 +86,7 @@ def test_performance(self):
8686 response = ["Density" , "Velocity" , "Viscosity" , "Length" ]
8787 answer = ["Density" , "Velocity" , "Viscosity" , "Length" ]
8888
89- result = evaluation_function (response , answer , self .config )
89+ result = evaluation_function (response , answer , self .param )
9090 processing_time = result .get ("result" , {}).get ("processing_time" , 0 )
9191
9292 self .assertLess (processing_time , 5 , msg = "Evaluation function should run efficiently." )
0 commit comments