88from nltk .corpus import stopwords
99from nltk .tokenize import word_tokenize
1010from nltk import pos_tag
11+ from src .qa .qa import TestRegistry
1112
# Pretrained checkpoint identifier handed to ``from_pretrained`` below.
model_name = "distilbert-base-uncased"
# Module-level tokenizer, created once at import time and shared by the tests
# in this module that need to tokenize sentences.
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
# Fetch the NLTK resources at import time: "punkt" for tokenization and
# "averaged_perceptron_tagger" for part-of-speech tagging.
nltk.download("punkt")
nltk.download("averaged_perceptron_tagger")
1920
20-
@TestRegistry.register("summary_length")
class LengthTest(LLMQaTest):
    """Test that measures how far the prediction's length is from the reference's.

    Registered with ``TestRegistry`` under the key ``"summary_length"``.
    """

    @property
    def test_name(self) -> str:
        """Return this test's identifier; identical to its registry key."""
        # No trailing whitespace: the name must equal the registered key.
        return "summary_length"

    def get_metric(
        self, prompt: str, ground_truth: str, model_prediction: str
    ) -> Union[float, int, bool]:
        """Return the absolute difference in character count between the texts.

        Args:
            prompt: Not used by this metric.
            ground_truth: Reference text.
            model_prediction: Text produced by the model.

        Returns:
            ``abs(len(ground_truth) - len(model_prediction))``; ``0`` means
            both strings have the same length.
        """
        return abs(len(ground_truth) - len(model_prediction))
3031
31-
32+ @ TestRegistry . register ( "jaccard_similarity" )
3233class JaccardSimilarityTest (LLMQaTest ):
3334 @property
3435 def test_name (self ) -> str :
35- return "Jaccard Similarity "
36+ return "jaccard_similarity "
3637
3738 def get_metric (
3839 self , prompt : str , ground_truth : str , model_prediction : str
@@ -46,11 +47,11 @@ def get_metric(
4647 similarity = intersection_size / union_size if union_size != 0 else 0
4748 return similarity
4849
49-
50+ @ TestRegistry . register ( "dot_product" )
5051class DotProductSimilarityTest (LLMQaTest ):
5152 @property
5253 def test_name (self ) -> str :
53- return "Semantic Similarity "
54+ return "dot_product "
5455
5556 def _encode_sentence (self , sentence ):
5657 tokens = tokenizer (sentence , return_tensors = "pt" )
@@ -68,11 +69,11 @@ def get_metric(
6869 )
6970 return dot_product_similarity
7071
71-
72+ @ TestRegistry . register ( "rouge_score" )
7273class RougeScoreTest (LLMQaTest ):
7374 @property
7475 def test_name (self ) -> str :
75- return "Rouge Score "
76+ return "rouge_score "
7677
7778 def get_metric (
7879 self , prompt : str , ground_truth : str , model_prediction : str
@@ -81,11 +82,11 @@ def get_metric(
8182 scores = scorer .score (model_prediction , ground_truth )
8283 return float (scores ["rouge1" ].precision )
8384
84-
85+ @ TestRegistry . register ( "word_overlap" )
8586class WordOverlapTest (LLMQaTest ):
8687 @property
8788 def test_name (self ) -> str :
88- return "Word Overlap Test "
89+ return "word_overlap "
8990
9091 def _remove_stopwords (self , text : str ) -> str :
9192 stop_words = set (stopwords .words ("english" ))
@@ -115,11 +116,11 @@ def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
115116 total_words = len (text .split (" " ))
116117 return round (len (pos_words ) / total_words , 2 )
117118
118-
119+ @ TestRegistry . register ( "verb_percent" )
119120class VerbPercent (PosCompositionTest ):
120121 @property
121122 def test_name (self ) -> str :
122- return "Verb Composition "
123+ return "verb_percent "
123124
124125 def get_metric (
125126 self , prompt : str , ground_truth : str , model_prediction : str
@@ -128,22 +129,22 @@ def get_metric(
128129 model_prediction , ["VB" , "VBD" , "VBG" , "VBN" , "VBP" , "VBZ" ]
129130 )
130131
131-
@TestRegistry.register("adjective_percent")
class AdjectivePercent(PosCompositionTest):
    """Test that reports the share of adjectives in the model prediction.

    Registered with ``TestRegistry`` under the key ``"adjective_percent"``.
    """

    @property
    def test_name(self) -> str:
        """Return this test's identifier; identical to its registry key."""
        # No trailing whitespace: the name must equal the registered key.
        return "adjective_percent"

    def get_metric(
        self, prompt: str, ground_truth: str, model_prediction: str
    ) -> float:
        """Return the adjective share of ``model_prediction``.

        Only ``model_prediction`` is inspected; ``prompt`` and ``ground_truth``
        are not used.

        Args:
            prompt: Not used by this metric.
            ground_truth: Not used by this metric.
            model_prediction: Text produced by the model.

        Returns:
            The value computed by ``_get_pos_percent`` for the tags ``JJ``,
            ``JJR`` and ``JJS`` (Penn Treebank adjective, comparative and
            superlative tags).
        """
        # NOTE(review): despite the name, the inherited helper appears to
        # return a ratio (matches / total words, rounded to 2 places), not a
        # value scaled to 0-100 -- confirm against PosCompositionTest.
        return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"])
141142
142-
143+ @ TestRegistry . register ( "noun_percent" )
143144class NounPercent (PosCompositionTest ):
144145 @property
145146 def test_name (self ) -> str :
146- return "Noun Composition "
147+ return "noun_percent "
147148
148149 def get_metric (
149150 self , prompt : str , ground_truth : str , model_prediction : str
0 commit comments