Skip to content

Commit e0093c5

Browse files
author
Baur
authored
Merge pull request #39 from ZenGuard-AI/baur/branch
[feature] add async detect and reporting for prompt_injection detector
2 parents f29b87d + 2c5ca15 commit e0093c5

File tree

3 files changed

+139
-12
lines changed

3 files changed

+139
-12
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "zenguard"
3-
version = "0.1.15"
3+
version = "0.1.16"
44
description = "Fast production grade security for GenAI applications"
55
authors = ["ZenGuard Team <hello@zenguard.ai>"]
66
license = "MIT"

tests/zenguard_e2e_test.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from zenguard.zenguard import Detector
1+
from unittest.mock import Mock, patch
2+
3+
import httpx
4+
import pytest
5+
6+
from zenguard.zenguard import API_REPORT_PROMPT_INJECTIONS, Detector
27

38

49
def assert_successful_response_not_detected(response):
@@ -74,14 +79,14 @@ def test_update_detectors(zenguard):
7479
assert response is None
7580

7681

77-
def test_detect_in_parallel(zenguard):
82+
def test_detect_in_parallel_error_no_detectors(zenguard):
7883
detectors = [Detector.SECRETS, Detector.ALLOWED_TOPICS]
7984
response = zenguard.update_detectors(detectors=detectors)
8085
assert response is None
8186

8287
prompt = "Simple in parallel test"
83-
response = zenguard.detect([], prompt)
84-
assert "No detectors" in response["error"]
88+
with pytest.raises(ValueError):
89+
response = zenguard.detect([], prompt)
8590

8691

8792
def test_detect_in_parallel_pass_on_detectors(zenguard):
@@ -91,3 +96,50 @@ def test_detect_in_parallel_pass_on_detectors(zenguard):
9196
response = zenguard.detect(detectors, prompt)
9297
assert_detectors_response(response, detectors)
9398
assert "error" not in response
99+
100+
101+
def test_prompt_injection_async(zenguard):
102+
prompt = "Simple prompt injection test"
103+
detectors = [Detector.PROMPT_INJECTION]
104+
zenguard.detect_async(detectors=detectors, prompt=prompt)
105+
106+
107+
def test_detect_error_no_detectors(zenguard):
108+
prompt = "Simple prompt injection test"
109+
with pytest.raises(ValueError):
110+
zenguard.detect_async([], prompt)
111+
112+
113+
def test_report_with_valid_detector_and_days(zenguard):
114+
with patch("httpx.post") as mock_post:
115+
mock_response = Mock()
116+
# TODO(baur): Update this to the actual response
117+
mock_response.json.return_value = {"prompt_injections": 10}
118+
mock_post.return_value = mock_response
119+
120+
result = zenguard.report(detector=Detector.PROMPT_INJECTION, days=7)
121+
122+
assert result == {"prompt_injections": 10}
123+
mock_post_args, mock_post_kwargs = mock_post.call_args
124+
125+
# Assert only the relevant parts of the API call
126+
assert API_REPORT_PROMPT_INJECTIONS in mock_post_args[0]
127+
assert mock_post_kwargs["json"] == {"days": 7}
128+
129+
130+
def test_report_with_invalid_detector(zenguard):
131+
with pytest.raises(ValueError):
132+
zenguard.report(detector=Detector.PII, days=7)
133+
134+
135+
def test_report_with_request_error(zenguard):
136+
with patch("httpx.post") as mock_post:
137+
mock_post.side_effect = httpx.RequestError("Connection error")
138+
139+
with pytest.raises(RuntimeError) as exc_info:
140+
zenguard.report(detector=Detector.PROMPT_INJECTION)
141+
142+
assert (
143+
str(exc_info.value)
144+
== "An error occurred while making the request: Connection error"
145+
)

zenguard/zenguard.py

Lines changed: 82 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
visualization,
2020
)
2121

22+
API_REPORT_PROMPT_INJECTIONS = "v1/report/prompt_injections"
23+
2224

2325
class SupportedLLMs:
2426
CHATGPT = "chatgpt"
@@ -99,8 +101,11 @@ def __init__(self, config: ZenGuardConfig):
99101
raise ValueError(f"LLM {config.llm} is not supported")
100102

101103
def detect(self, detectors: list[Detector], prompt: str):
104+
"""
105+
Uses detectors to evaluate the prompt and return the results.
106+
"""
102107
if len(detectors) == 0:
103-
return {"error": "No detectors were provided"}
108+
raise ValueError("No detectors were provided")
104109

105110
url = self._backend
106111
if len(detectors) == 1:
@@ -118,15 +123,48 @@ def detect(self, detectors: list[Detector], prompt: str):
118123
timeout=20,
119124
)
120125
except httpx.RequestError as e:
121-
print(e)
122-
return {"error": str(e)}
123-
124-
if response.status_code != 200:
125-
print(response.json())
126-
return {"error": str(response.json())}
126+
raise RuntimeError(
127+
f"An error occurred while making the request: {str(e)}"
128+
) from e
129+
except httpx.HTTPStatusError as e:
130+
raise RuntimeError(
131+
f"Received an unexpected status code: {response.status_code}\nResponse content: {response.json()}"
132+
) from e
127133

128134
return response.json()
129135

136+
def detect_async(self, detectors: list[Detector], prompt: str):
137+
"""
138+
Same as detect function but asynchroneous.
139+
"""
140+
if len(detectors) == 0:
141+
raise ValueError("No detectors were provided")
142+
143+
if detectors[0] != Detector.PROMPT_INJECTION:
144+
raise ValueError(
145+
"Only Prompt Injection detector is supported for async detection"
146+
)
147+
148+
url = self._backend + convert_detector_to_api(detectors[0]) + "_async"
149+
json = {"messages": [prompt]}
150+
151+
try:
152+
response = httpx.post(
153+
url,
154+
json=json,
155+
headers={"x-api-key": self._api_key},
156+
timeout=20,
157+
)
158+
response.raise_for_status()
159+
except httpx.RequestError as e:
160+
raise RuntimeError(
161+
f"An error occurred while making the request: {str(e)}"
162+
) from e
163+
except httpx.HTTPStatusError as e:
164+
raise RuntimeError(
165+
f"Received an unexpected status code: {response.status_code}\nResponse content: {response.json()}"
166+
) from e
167+
130168
def _attack_zenguard(self, detector: Detector, attacks: list[str]):
131169
attacks = tqdm(attacks)
132170
for attack in attacks:
@@ -171,3 +209,40 @@ def update_detectors(self, detectors: list[Detector]):
171209

172210
if response.status_code != 200:
173211
return {"error": str(response.json())}
212+
213+
def report(self, detector: Detector, days: int = None):
214+
"""
215+
Get a report of the detections made by the detector in the last days.
216+
Days is optional and if not provided, it will return all the detections.
217+
Days is int and will give back the number of detections made in the last days.
218+
"""
219+
220+
if detector != Detector.PROMPT_INJECTION:
221+
raise ValueError(
222+
"Only Prompt Injection detector is currently supported for reports"
223+
)
224+
225+
json = {}
226+
if days:
227+
json = {"days": days}
228+
229+
url = self._backend + API_REPORT_PROMPT_INJECTIONS
230+
231+
try:
232+
response = httpx.post(
233+
url,
234+
json=json,
235+
headers={"x-api-key": self._api_key},
236+
timeout=20,
237+
)
238+
response.raise_for_status()
239+
except httpx.RequestError as e:
240+
raise RuntimeError(
241+
f"An error occurred while making the request: {str(e)}"
242+
) from e
243+
except httpx.HTTPStatusError as e:
244+
raise RuntimeError(
245+
f"Received an unexpected status code: {response.status_code}\nResponse content: {response.text}"
246+
) from e
247+
248+
return response.json()

0 commit comments

Comments
 (0)