|
1 | 1 | from zenguard.zenguard import Detector |
2 | 2 |
|
3 | 3 |
|
4 | | -def assert_successful_response_not_detected(response): |
| 4 | +def assert_successful_response_not_detected(response, detectors): |
5 | 5 | assert response is not None |
6 | | - assert "error" not in response, f"API returned an error: {response.get('error')}" |
7 | | - assert response.get("is_detected") is False, f"Prompt was detected: {response}" |
| 6 | + for detector in detectors: |
| 7 | + common_response = next(( |
| 8 | + resp["common_response"] |
| 9 | + for resp in response["responses"] |
| 10 | + if resp["detector"] == detector.value |
| 11 | + )) |
| 12 | + assert "err" not in common_response, f"API returned an error: {common_response.get('err')}" |
| 13 | + assert common_response.get("is_detected") is False, f"Prompt was detected: {common_response}" |
8 | 14 |
|
9 | 15 |
|
10 | 16 | def test_prompt_injection(zenguard): |
11 | 17 | prompt = "Simple prompt injection test" |
12 | 18 | detectors = [Detector.PROMPT_INJECTION] |
13 | 19 | response = zenguard.detect(detectors=detectors, prompt=prompt) |
14 | | - assert_successful_response_not_detected(response) |
| 20 | + assert_successful_response_not_detected(response, detectors) |
15 | 21 |
|
16 | 22 |
|
17 | 23 | def test_pii(zenguard): |
18 | 24 | prompt = "Simple PII test" |
19 | 25 | detectors = [Detector.PII] |
20 | 26 | response = zenguard.detect(detectors=detectors, prompt=prompt) |
21 | | - assert_successful_response_not_detected(response) |
| 27 | + assert_successful_response_not_detected(response, detectors) |
22 | 28 |
|
23 | 29 |
|
24 | 30 | def test_allowed_topics(zenguard): |
25 | 31 | prompt = "Simple allowed topics test" |
26 | 32 | detectors = [Detector.ALLOWED_TOPICS] |
27 | 33 | response = zenguard.detect(detectors=detectors, prompt=prompt) |
28 | | - assert_successful_response_not_detected(response) |
| 34 | + assert_successful_response_not_detected(response, detectors) |
29 | 35 |
|
30 | 36 |
|
31 | 37 | def test_banned_topics(zenguard): |
32 | 38 | prompt = "Simple banned topics test" |
33 | 39 | detectors = [Detector.BANNED_TOPICS] |
34 | 40 | response = zenguard.detect(detectors=detectors, prompt=prompt) |
35 | | - assert_successful_response_not_detected(response) |
| 41 | + assert_successful_response_not_detected(response, detectors) |
36 | 42 |
|
37 | 43 |
|
38 | 44 | def test_keywords(zenguard): |
39 | 45 | prompt = "Simple keywords test" |
40 | 46 | detectors = [Detector.KEYWORDS] |
41 | 47 | response = zenguard.detect(detectors=detectors, prompt=prompt) |
42 | | - assert_successful_response_not_detected(response) |
| 48 | + assert_successful_response_not_detected(response, detectors) |
43 | 49 |
|
44 | 50 |
|
45 | 51 | def test_secrets(zenguard): |
46 | 52 | prompt = "Simple secrets test" |
47 | 53 | detectors = [Detector.SECRETS] |
48 | 54 | response = zenguard.detect(detectors=detectors, prompt=prompt) |
49 | | - assert_successful_response_not_detected(response) |
| 55 | + assert_successful_response_not_detected(response, detectors) |
| 56 | + |
| 57 | + |
| 58 | +def test_update_detectors(zenguard): |
| 59 | + detectors = [Detector.SECRETS, Detector.ALLOWED_TOPICS] |
| 60 | + response = zenguard.update_detectors(detectors=detectors) |
| 61 | + assert response is None |
| 62 | + |
| 63 | + |
| 64 | +def test_detect_in_parallel(zenguard): |
| 65 | + detectors = [Detector.SECRETS, Detector.ALLOWED_TOPICS] |
| 66 | + response = zenguard.update_detectors(detectors=detectors) |
| 67 | + assert response is None |
| 68 | + |
| 69 | + prompt = "Simple in parallel test" |
| 70 | + response = zenguard.detect([], prompt) |
| 71 | + assert_successful_response_not_detected(response, detectors) |
| 72 | + |
| 73 | + |
| 74 | +def test_detect_in_parallel_pass_on_detectors(zenguard): |
| 75 | + detectors = [Detector.SECRETS, Detector.BANNED_TOPICS] |
| 76 | + |
| 77 | + prompt = "Simple in parallel test" |
| 78 | + response = zenguard.detect(detectors, prompt) |
| 79 | + assert_successful_response_not_detected(response, detectors) |
50 | 80 |
|
51 | 81 |
|
52 | 82 | def test_toxicity(zenguard): |
53 | 83 | prompt = "Simple toxicity test" |
54 | 84 | detectors = [Detector.TOXICITY] |
55 | 85 | response = zenguard.detect(detectors=detectors, prompt=prompt) |
56 | | - assert_successful_response_not_detected(response) |
| 86 | + assert_successful_response_not_detected(response, detectors) |
0 commit comments