Skip to content

Commit 5d3fa49

Browse files
authored
Merge pull request #99 from VivekSinghDS/add/test-case
updated changes
2 parents 89919a3 + e039a90 commit 5d3fa49

File tree

4 files changed

+54
-26
lines changed

4 files changed

+54
-26
lines changed

config.yml

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,15 @@ inference:
6969
use_cache: True
7070
do_sample: True
7171
top_p: 0.9
72-
temperature: 0.8
72+
temperature: 0.8
73+
74+
qa:
75+
llm_tests:
76+
- jaccard_similarity
77+
- dot_product
78+
- rouge_score
79+
- word_overlap
80+
- verb_percent
81+
- adjective_percent
82+
- noun_percent
83+
- summary_length

src/pydantic_models/config_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
# TODO: Refactor this into multiple files...
99
HfModelPath = str
1010

11+
class QaConfig(BaseModel):
12+
llm_tests: Optional[List[str]] = Field([], description = "list of tests that need to be connected")
13+
1114

1215
class DataConfig(BaseModel):
1316
file_type: Literal["json", "csv", "huggingface"] = Field(

src/qa/qa.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from abc import ABC, abstractmethod
22
from typing import Union, List, Tuple, Dict
33
import pandas as pd
4-
from toolkit.src.ui.rich_ui import RichUI
4+
from src.ui.rich_ui import RichUI
55
import statistics
6+
from src.qa.qa_tests import *
67

78

89
class LLMQaTest(ABC):
@@ -17,15 +18,27 @@ def get_metric(
1718
) -> Union[float, int, bool]:
1819
pass
1920

21+
class QaTestRegistry:
22+
registry = {}
2023

21-
class LLMTestSuite:
22-
def __init__(
23-
self,
24-
tests: List[LLMQaTest],
25-
prompts: List[str],
26-
ground_truths: List[str],
27-
model_preds: List[str],
28-
) -> None:
24+
@classmethod
25+
def register(cls, *names):
26+
def inner_wrapper(wrapped_class):
27+
for name in names:
28+
cls.registry[name] = wrapped_class
29+
return wrapped_class
30+
return inner_wrapper
31+
32+
@classmethod
33+
def create_tests_from_list(cls, test_names: List[str]) -> List[LLMQaTest]:
34+
return [cls.registry[name]() for name in test_names]
35+
36+
class LLMTestSuite():
37+
def __init__(self,
38+
tests:List[LLMQaTest],
39+
prompts:List[str],
40+
ground_truths:List[str],
41+
model_preds:List[str]) -> None:
2942

3043
self.tests = tests
3144
self.prompts = prompts

src/qa/qa_tests.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from nltk.corpus import stopwords
99
from nltk.tokenize import word_tokenize
1010
from nltk import pos_tag
11+
from src.qa.qa import QaTestRegistry as TestRegistry
1112

1213
model_name = "distilbert-base-uncased"
1314
tokenizer = DistilBertTokenizer.from_pretrained(model_name)
@@ -17,22 +18,22 @@
1718
nltk.download("punkt")
1819
nltk.download("averaged_perceptron_tagger")
1920

20-
21+
@TestRegistry.register("summary_length")
2122
class LengthTest(LLMQaTest):
2223
@property
2324
def test_name(self) -> str:
24-
return "Summary Length Test"
25+
return "summary_length"
2526

2627
def get_metric(
2728
self, prompt: str, ground_truth: str, model_prediction: str
2829
) -> Union[float, int, bool]:
2930
return abs(len(ground_truth) - len(model_prediction))
3031

31-
32+
@TestRegistry.register("jaccard_similarity")
3233
class JaccardSimilarityTest(LLMQaTest):
3334
@property
3435
def test_name(self) -> str:
35-
return "Jaccard Similarity"
36+
return "jaccard_similarity"
3637

3738
def get_metric(
3839
self, prompt: str, ground_truth: str, model_prediction: str
@@ -46,11 +47,11 @@ def get_metric(
4647
similarity = intersection_size / union_size if union_size != 0 else 0
4748
return similarity
4849

49-
50+
@TestRegistry.register("dot_product")
5051
class DotProductSimilarityTest(LLMQaTest):
5152
@property
5253
def test_name(self) -> str:
53-
return "Semantic Similarity"
54+
return "dot_product"
5455

5556
def _encode_sentence(self, sentence):
5657
tokens = tokenizer(sentence, return_tensors="pt")
@@ -68,11 +69,11 @@ def get_metric(
6869
)
6970
return dot_product_similarity
7071

71-
72+
@TestRegistry.register("rouge_score")
7273
class RougeScoreTest(LLMQaTest):
7374
@property
7475
def test_name(self) -> str:
75-
return "Rouge Score"
76+
return "rouge_score"
7677

7778
def get_metric(
7879
self, prompt: str, ground_truth: str, model_prediction: str
@@ -81,11 +82,11 @@ def get_metric(
8182
scores = scorer.score(model_prediction, ground_truth)
8283
return float(scores["rouge1"].precision)
8384

84-
85+
@TestRegistry.register("word_overlap")
8586
class WordOverlapTest(LLMQaTest):
8687
@property
8788
def test_name(self) -> str:
88-
return "Word Overlap Test"
89+
return "word_overlap"
8990

9091
def _remove_stopwords(self, text: str) -> str:
9192
stop_words = set(stopwords.words("english"))
@@ -115,11 +116,11 @@ def _get_pos_percent(self, text: str, pos_tags: List[str]) -> float:
115116
total_words = len(text.split(" "))
116117
return round(len(pos_words) / total_words, 2)
117118

118-
119+
@TestRegistry.register("verb_percent")
119120
class VerbPercent(PosCompositionTest):
120121
@property
121122
def test_name(self) -> str:
122-
return "Verb Composition"
123+
return "verb_percent"
123124

124125
def get_metric(
125126
self, prompt: str, ground_truth: str, model_prediction: str
@@ -128,22 +129,22 @@ def get_metric(
128129
model_prediction, ["VB", "VBD", "VBG", "VBN", "VBP", "VBZ"]
129130
)
130131

131-
132+
@TestRegistry.register("adjective_percent")
132133
class AdjectivePercent(PosCompositionTest):
133134
@property
134135
def test_name(self) -> str:
135-
return "Adjective Composition"
136+
return "adjective_percent"
136137

137138
def get_metric(
138139
self, prompt: str, ground_truth: str, model_prediction: str
139140
) -> float:
140141
return self._get_pos_percent(model_prediction, ["JJ", "JJR", "JJS"])
141142

142-
143+
@TestRegistry.register("noun_percent")
143144
class NounPercent(PosCompositionTest):
144145
@property
145146
def test_name(self) -> str:
146-
return "Noun Composition"
147+
return "noun_percent"
147148

148149
def get_metric(
149150
self, prompt: str, ground_truth: str, model_prediction: str

0 commit comments

Comments
 (0)